From d12ba2cc435acdaa1b1e73333e25ee6cb7dd5b93 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sun, 24 Mar 2024 17:32:24 +0100 Subject: [PATCH 01/11] basic test setup --- .../tests/compute_pass_resource_ownership.rs | 128 ++++++++++++++++++ tests/tests/root.rs | 1 + 2 files changed, 129 insertions(+) create mode 100644 tests/tests/compute_pass_resource_ownership.rs diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs new file mode 100644 index 0000000000..66b962908e --- /dev/null +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -0,0 +1,128 @@ +//! Tests that compute passes take ownership of resources that are passed in. +//! I.e. once a resource is passed in to a compute pass, it can be dropped. + +use std::num::NonZeroU64; + +use wgpu::util::DeviceExt as _; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +const SHADER_SRC: &str = " +@group(0) @binding(0) +var buffer: array; + +@compute @workgroup_size(1, 1, 1) fn main() { + buffer[0] *= 2.0; +} +"; + +#[gpu_test] +static COMPUTE_PASS_RESOURCE_OWNERSHIP: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().test_features_limits()) + .run_async(compute_pass_resource_ownership); + +async fn compute_pass_resource_ownership(ctx: TestingContext) { + let sm = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("shader"), + source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()), + }); + + let buffer_size = 4 * std::mem::size_of::() as u64; + + let bgl = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("bind_group_layout"), + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: NonZeroU64::new(buffer_size), + }, + count: None, + }], + 
}); + + let gpu_buffer = ctx + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("gpu_buffer"), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + contents: bytemuck::bytes_of(&[1.0_f32, 2.0, 3.0, 4.0]), + }); + + let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("cpu_buffer"), + size: buffer_size, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + let indirect_buffer = ctx + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("gpu_buffer"), + usage: wgpu::BufferUsages::INDIRECT, + contents: wgpu::util::DispatchIndirectArgs { x: 1, y: 1, z: 1 }.as_bytes(), + }); + + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("bind_group"), + layout: &bgl, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: gpu_buffer.as_entire_binding(), + }], + }); + + let pipeline_layout = ctx + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("pipeline_layout"), + bind_group_layouts: &[&bgl], + push_constant_ranges: &[], + }); + + let pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("pipeline"), + layout: Some(&pipeline_layout), + module: &sm, + entry_point: "main", + }); + + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("encoder"), + }); + + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("compute_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); + + // TODO: Now drop all resources we set. Then do a device pool. + // TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet. 
+ } + + encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size); + ctx.queue.submit([encoder.finish()]); + cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ()); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let data = cpu_buffer.slice(..).get_mapped_range(); + + let floats: &[f32] = bytemuck::cast_slice(&data); + assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]); +} diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 6dc7af56ec..ba5e020791 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -11,6 +11,7 @@ mod buffer; mod buffer_copy; mod buffer_usages; mod clear_texture; +mod compute_pass_resource_ownership; mod create_surface_error; mod device; mod encoder; From a86ef5c572bdfdd91970f62865dfd13eb4be1ad5 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sun, 24 Mar 2024 17:36:26 +0100 Subject: [PATCH 02/11] remove lifetime and drop resources on test - test fails now just as expected --- tests/tests/compute_pass_resource_ownership.rs | 16 +++++++++++++--- wgpu/src/lib.rs | 6 +++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs index 66b962908e..863c03b304 100644 --- a/tests/tests/compute_pass_resource_ownership.rs +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -1,5 +1,9 @@ //! Tests that compute passes take ownership of resources that are passed in. //! I.e. once a resource is passed in to a compute pass, it can be dropped. +//! +//! TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet. +//! (Not important as long as they are lifetime constrained to the command encoder, +//! but once we lift this constraint, we should add tests for this as well!) 
use std::num::NonZeroU64; @@ -93,6 +97,7 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) { layout: Some(&pipeline_layout), module: &sm, entry_point: "main", + compilation_options: Default::default(), }); let mut encoder = ctx @@ -104,14 +109,19 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) { { let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("compute_pass"), - timestamp_writes: None, + timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound. }); cpass.set_pipeline(&pipeline); cpass.set_bind_group(0, &bind_group, &[]); cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); - // TODO: Now drop all resources we set. Then do a device pool. - // TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet. + // Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what. + drop(pipeline); + drop(bind_group); + drop(indirect_buffer); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); } encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size); diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index 2807c55cb9..2f5eed0666 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -4411,7 +4411,7 @@ impl<'a> ComputePass<'a> { pub fn set_bind_group( &mut self, index: u32, - bind_group: &'a BindGroup, + bind_group: &BindGroup, offsets: &[DynamicOffset], ) { DynContext::compute_pass_set_bind_group( @@ -4426,7 +4426,7 @@ impl<'a> ComputePass<'a> { } /// Sets the active compute pipeline. 
- pub fn set_pipeline(&mut self, pipeline: &'a ComputePipeline) { + pub fn set_pipeline(&mut self, pipeline: &ComputePipeline) { DynContext::compute_pass_set_pipeline( &*self.parent.context, &mut self.id, @@ -4484,7 +4484,7 @@ impl<'a> ComputePass<'a> { /// The structure expected in `indirect_buffer` must conform to [`DispatchIndirectArgs`](crate::util::DispatchIndirectArgs). pub fn dispatch_workgroups_indirect( &mut self, - indirect_buffer: &'a Buffer, + indirect_buffer: &Buffer, indirect_offset: BufferAddress, ) { DynContext::compute_pass_dispatch_workgroups_indirect( From e39fa3b5c75ecc9f7252e80f426a84758d3571ca Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sun, 24 Mar 2024 18:01:28 +0100 Subject: [PATCH 03/11] compute pass recording is now hub dependent (needs gfx_select) --- wgpu-core/src/command/compute.rs | 71 +++++++++++++++++++------------- wgpu/src/backend/wgpu_core.rs | 35 ++++++++++------ 2 files changed, 66 insertions(+), 40 deletions(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 4ee48f0086..3f1d060c05 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -1,29 +1,26 @@ -use crate::command::compute_command::{ArcComputeCommand, ComputeCommand}; -use crate::device::DeviceError; -use crate::resource::Resource; -use crate::snatch::SnatchGuard; -use crate::track::TrackerIndex; use crate::{ binding_model::{ BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError, }, command::{ bind::Binder, + compute_command::{ArcComputeCommand, ComputeCommand}, end_pipeline_statistics_query, memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState}, BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange, }, - device::{MissingDownlevelFlags, MissingFeatures}, + device::{DeviceError, MissingDownlevelFlags, MissingFeatures}, error::{ErrorFormatter, PrettyError}, 
global::Global, hal_api::HalApi, - hal_label, id, - id::DeviceId, + hal_label, + id::{self, DeviceId}, init_tracker::MemoryInitKind, - resource::{self}, + resource::{self, Resource}, + snatch::SnatchGuard, storage::Storage, - track::{Tracker, UsageConflict, UsageScope}, + track::{Tracker, TrackerIndex, UsageConflict, UsageScope}, validation::{check_buffer_usage, MissingBufferUsageError}, Label, }; @@ -35,6 +32,7 @@ use serde::Deserialize; use serde::Serialize; use thiserror::Error; +use wgt::{BufferAddress, DynamicOffset}; use std::sync::Arc; use std::{fmt, mem, str}; @@ -286,7 +284,7 @@ impl<'a, A: HalApi> State<'a, A> { } } -// Common routines between render/compute +// Running the compute pass. impl Global { pub fn command_encoder_run_compute_pass( @@ -845,13 +843,10 @@ impl Global { } } -pub mod compute_commands { - use super::{ComputeCommand, ComputePass}; - use crate::id; - use std::convert::TryInto; - use wgt::{BufferAddress, DynamicOffset}; - - pub fn wgpu_compute_pass_set_bind_group( +// Recording a compute pass. 
+impl Global { + pub fn compute_pass_set_bind_group( + &self, pass: &mut ComputePass, index: u32, bind_group_id: id::BindGroupId, @@ -875,7 +870,8 @@ pub mod compute_commands { }); } - pub fn wgpu_compute_pass_set_pipeline( + pub fn compute_pass_set_pipeline( + &self, pass: &mut ComputePass, pipeline_id: id::ComputePipelineId, ) { @@ -888,7 +884,12 @@ pub mod compute_commands { .push(ComputeCommand::SetPipeline(pipeline_id)); } - pub fn wgpu_compute_pass_set_push_constant(pass: &mut ComputePass, offset: u32, data: &[u8]) { + pub fn compute_pass_set_push_constant( + &self, + pass: &mut ComputePass, + offset: u32, + data: &[u8], + ) { assert_eq!( offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1), 0, @@ -915,7 +916,8 @@ pub mod compute_commands { }); } - pub fn wgpu_compute_pass_dispatch_workgroups( + pub fn compute_pass_dispatch_workgroups( + &self, pass: &mut ComputePass, groups_x: u32, groups_y: u32, @@ -926,7 +928,8 @@ pub mod compute_commands { .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z])); } - pub fn wgpu_compute_pass_dispatch_workgroups_indirect( + pub fn compute_pass_dispatch_workgroups_indirect( + &self, pass: &mut ComputePass, buffer_id: id::BufferId, offset: BufferAddress, @@ -936,7 +939,12 @@ pub mod compute_commands { .push(ComputeCommand::DispatchIndirect { buffer_id, offset }); } - pub fn wgpu_compute_pass_push_debug_group(pass: &mut ComputePass, label: &str, color: u32) { + pub fn compute_pass_push_debug_group( + &self, + pass: &mut ComputePass, + label: &str, + color: u32, + ) { let bytes = label.as_bytes(); pass.base.string_data.extend_from_slice(bytes); @@ -946,11 +954,16 @@ pub mod compute_commands { }); } - pub fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) { + pub fn compute_pass_pop_debug_group(&self, pass: &mut ComputePass) { pass.base.commands.push(ComputeCommand::PopDebugGroup); } - pub fn wgpu_compute_pass_insert_debug_marker(pass: &mut ComputePass, label: &str, color: u32) { + pub fn 
compute_pass_insert_debug_marker( + &self, + pass: &mut ComputePass, + label: &str, + color: u32, + ) { let bytes = label.as_bytes(); pass.base.string_data.extend_from_slice(bytes); @@ -960,7 +973,8 @@ pub mod compute_commands { }); } - pub fn wgpu_compute_pass_write_timestamp( + pub fn compute_pass_write_timestamp( + &self, pass: &mut ComputePass, query_set_id: id::QuerySetId, query_index: u32, @@ -971,7 +985,8 @@ pub mod compute_commands { }); } - pub fn wgpu_compute_pass_begin_pipeline_statistics_query( + pub fn compute_pass_begin_pipeline_statistics_query( + &self, pass: &mut ComputePass, query_set_id: id::QuerySetId, query_index: u32, @@ -984,7 +999,7 @@ pub mod compute_commands { }); } - pub fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) { + pub fn compute_pass_end_pipeline_statistics_query(&self, pass: &mut ComputePass) { pass.base .commands .push(ComputeCommand::EndPipelineStatisticsQuery); diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index f1bdf13f0a..1253520dc7 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -23,7 +23,7 @@ use std::{ sync::Arc, }; use wgc::{ - command::{bundle_ffi::*, compute_commands::*, render_commands::*}, + command::{bundle_ffi::*, render_commands::*}, device::DeviceLostClosure, id::{CommandEncoderId, TextureViewId}, }; @@ -2311,7 +2311,8 @@ impl crate::Context for ContextWgpuCore { pipeline: &Self::ComputePipelineId, _pipeline_data: &Self::ComputePipelineData, ) { - wgpu_compute_pass_set_pipeline(pass_data, *pipeline) + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_set_pipeline(pass_data, *pipeline)) } fn compute_pass_set_bind_group( @@ -2323,7 +2324,8 @@ impl crate::Context for ContextWgpuCore { _bind_group_data: &Self::BindGroupData, offsets: &[wgt::DynamicOffset], ) { - wgpu_compute_pass_set_bind_group(pass_data, index, *bind_group, offsets); + let encoder = pass_data.parent_id(); + 
wgc::gfx_select!(encoder => self.0.compute_pass_set_bind_group(pass_data, index, *bind_group, offsets)); } fn compute_pass_set_push_constants( @@ -2333,7 +2335,8 @@ impl crate::Context for ContextWgpuCore { offset: u32, data: &[u8], ) { - wgpu_compute_pass_set_push_constant(pass_data, offset, data); + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_set_push_constant(pass_data, offset, data)); } fn compute_pass_insert_debug_marker( @@ -2342,7 +2345,8 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, label: &str, ) { - wgpu_compute_pass_insert_debug_marker(pass_data, label, 0); + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_insert_debug_marker(pass_data, label, 0)); } fn compute_pass_push_debug_group( @@ -2351,7 +2355,8 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, group_label: &str, ) { - wgpu_compute_pass_push_debug_group(pass_data, group_label, 0); + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_push_debug_group(pass_data, group_label, 0)); } fn compute_pass_pop_debug_group( @@ -2359,7 +2364,8 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - wgpu_compute_pass_pop_debug_group(pass_data); + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_pop_debug_group(pass_data)); } fn compute_pass_write_timestamp( @@ -2370,7 +2376,8 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - wgpu_compute_pass_write_timestamp(pass_data, *query_set, query_index) + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_write_timestamp(pass_data, *query_set, query_index)); } fn compute_pass_begin_pipeline_statistics_query( @@ -2381,7 +2388,8 @@ impl crate::Context for ContextWgpuCore { _query_set_data: 
&Self::QuerySetData, query_index: u32, ) { - wgpu_compute_pass_begin_pipeline_statistics_query(pass_data, *query_set, query_index) + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_begin_pipeline_statistics_query(pass_data, *query_set, query_index)); } fn compute_pass_end_pipeline_statistics_query( @@ -2389,7 +2397,8 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - wgpu_compute_pass_end_pipeline_statistics_query(pass_data) + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_end_pipeline_statistics_query(pass_data)); } fn compute_pass_dispatch_workgroups( @@ -2400,7 +2409,8 @@ impl crate::Context for ContextWgpuCore { y: u32, z: u32, ) { - wgpu_compute_pass_dispatch_workgroups(pass_data, x, y, z) + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups(pass_data, x, y, z)); } fn compute_pass_dispatch_workgroups_indirect( @@ -2411,7 +2421,8 @@ impl crate::Context for ContextWgpuCore { _indirect_buffer_data: &Self::BufferData, indirect_offset: wgt::BufferAddress, ) { - wgpu_compute_pass_dispatch_workgroups_indirect(pass_data, *indirect_buffer, indirect_offset) + let encoder = pass_data.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups_indirect(pass_data, *indirect_buffer, indirect_offset)); } fn render_bundle_encoder_set_pipeline( From 8af13dffa7b41fd0f6fb7e080d7c207fdbf3c082 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sat, 30 Mar 2024 13:18:47 +0100 Subject: [PATCH 04/11] compute pass recording now bumps reference count of uses resources directly on recording TODO: * bind groups don't work because the Binder gets an id only * wgpu level error handling is missing --- wgpu-core/src/command/compute.rs | 208 ++++++++++++++++++++++--------- wgpu/src/backend/wgpu_core.rs | 79 +++++++----- 2 files changed, 194 insertions(+), 93 
deletions(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 3f1d060c05..d81958bf3c 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -37,23 +37,24 @@ use wgt::{BufferAddress, DynamicOffset}; use std::sync::Arc; use std::{fmt, mem, str}; -#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct ComputePass { - base: BasePass, + // TODO(#5124) / workaround for generic proliferation: + // We want to store `BasePass>` here, but this would mean + // that `ComputePass` becomes generic over HalApi meaning it would need to be identified by + // an identifier as well, causing another array on the hub and another indirection on every access. + base: Box, parent_id: id::CommandEncoderId, timestamp_writes: Option, // Resource binding dedupe state. - #[cfg_attr(feature = "serde", serde(skip))] current_bind_groups: BindGroupStateChange, - #[cfg_attr(feature = "serde", serde(skip))] current_pipeline: StateChange, } impl ComputePass { - pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self { + fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self { Self { - base: BasePass::new(&desc.label), + base: Box::new(BasePass::>::new(&desc.label)), parent_id, timestamp_writes: desc.timestamp_writes.cloned(), @@ -66,24 +67,22 @@ impl ComputePass { self.parent_id } - #[cfg(feature = "trace")] - pub fn into_command(self) -> crate::device::trace::Command { - crate::device::trace::Command::RunComputePass { - base: self.base, - timestamp_writes: self.timestamp_writes, - } + fn base(&self) -> &BasePass> { + self.base + .downcast_ref() + .expect("Downcast failed, unexpected backend") + } + + fn base_mut(&mut self) -> &mut BasePass> { + self.base + .downcast_mut() + .expect("Downcast failed, unexpected backend") } } impl fmt::Debug for ComputePass { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "ComputePass {{ 
encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}", - self.parent_id, - self.base.commands.len(), - self.base.dynamic_offsets.len() - ) + write!(f, "ComputePass {{ encoder_id: {:?} }}", self.parent_id,) } } @@ -287,15 +286,22 @@ impl<'a, A: HalApi> State<'a, A> { // Running the compute pass. impl Global { + pub fn command_encoder_create_compute_pass( + &self, + parent_id: id::CommandEncoderId, + desc: &ComputePassDescriptor, + ) -> ComputePass { + ComputePass::new::(parent_id, desc) + } + pub fn command_encoder_run_compute_pass( &self, encoder_id: id::CommandEncoderId, pass: &ComputePass, ) -> Result<(), ComputePassError> { - // TODO: This should go directly to `command_encoder_run_compute_pass_impl` by means of storing `ArcComputeCommand` internally. - self.command_encoder_run_compute_pass_with_unresolved_commands::( + self.command_encoder_run_compute_pass_impl::( encoder_id, - pass.base.as_ref(), + pass.base().as_ref(), pass.timestamp_writes.as_ref(), ) } @@ -851,37 +857,69 @@ impl Global { index: u32, bind_group_id: id::BindGroupId, offsets: &[DynamicOffset], - ) { + ) -> Result<(), ComputePassError> { + let base: &mut BasePass> = pass + .base + .downcast_mut() + .expect("Downcast failed, unexpected backend"); + //pass.base_mut::(); // borrow checker not happy with using this util. + let redundant = pass.current_bind_groups.set_and_check_redundant( bind_group_id, index, - &mut pass.base.dynamic_offsets, + &mut base.dynamic_offsets, offsets, ); if redundant { - return; + return Ok(()); } - pass.base.commands.push(ComputeCommand::SetBindGroup { + let hub = A::hub(self); + let bind_group = hub + .bind_groups + .read() + .get(bind_group_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::SetBindGroup(bind_group_id), + inner: ComputePassErrorInner::InvalidBindGroup(index), + })? 
+ .clone(); + + base.commands.push(ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets: offsets.len(), - bind_group_id, + bind_group, }); + + Ok(()) } pub fn compute_pass_set_pipeline( &self, pass: &mut ComputePass, pipeline_id: id::ComputePipelineId, - ) { + ) -> Result<(), ComputePassError> { if pass.current_pipeline.set_and_check_redundant(pipeline_id) { - return; + return Ok(()); } - pass.base + let hub = A::hub(self); + let pipeline = hub + .compute_pipelines + .read() + .get(pipeline_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::SetPipelineCompute(pipeline_id), + inner: ComputePassErrorInner::InvalidPipeline(pipeline_id), + })? + .clone(); + + pass.base_mut() .commands - .push(ComputeCommand::SetPipeline(pipeline_id)); + .push(ArcComputeCommand::SetPipeline(pipeline)); + + Ok(()) } pub fn compute_pass_set_push_constant( @@ -900,16 +938,17 @@ impl Global { 0, "Push constant size must be aligned to 4 bytes." ); - let value_offset = pass.base.push_constant_data.len().try_into().expect( + let base = pass.base_mut(); + let value_offset = base.push_constant_data.len().try_into().expect( "Ran out of push constant space. 
Don't set 4gb of push constants per ComputePass.", - ); + ); // TODO: make this an error that can be handled - pass.base.push_constant_data.extend( + base.push_constant_data.extend( data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize) .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])), ); - pass.base.commands.push(ComputeCommand::SetPushConstant { + base.commands.push(ArcComputeCommand::::SetPushConstant { offset, size_bytes: data.len() as u32, values_offset: value_offset, @@ -923,9 +962,11 @@ impl Global { groups_y: u32, groups_z: u32, ) { - pass.base + pass.base_mut() .commands - .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z])); + .push(ArcComputeCommand::::Dispatch([ + groups_x, groups_y, groups_z, + ])); } pub fn compute_pass_dispatch_workgroups_indirect( @@ -933,10 +974,26 @@ impl Global { pass: &mut ComputePass, buffer_id: id::BufferId, offset: BufferAddress, - ) { - pass.base + ) -> Result<(), ComputePassError> { + let hub = A::hub(self); + let buffer = hub + .buffers + .read() + .get(buffer_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::Dispatch { + indirect: true, + pipeline: pass.current_pipeline.last_state, + }, + inner: ComputePassErrorInner::InvalidBuffer(buffer_id), + })? 
+ .clone(); + + pass.base_mut() .commands - .push(ComputeCommand::DispatchIndirect { buffer_id, offset }); + .push(ArcComputeCommand::::DispatchIndirect { buffer, offset }); + + Ok(()) } pub fn compute_pass_push_debug_group( @@ -946,16 +1003,19 @@ impl Global { color: u32, ) { let bytes = label.as_bytes(); - pass.base.string_data.extend_from_slice(bytes); + let base = pass.base_mut(); + base.string_data.extend_from_slice(bytes); - pass.base.commands.push(ComputeCommand::PushDebugGroup { + base.commands.push(ArcComputeCommand::::PushDebugGroup { color, len: bytes.len(), }); } pub fn compute_pass_pop_debug_group(&self, pass: &mut ComputePass) { - pass.base.commands.push(ComputeCommand::PopDebugGroup); + pass.base_mut::() + .commands + .push(ArcComputeCommand::::PopDebugGroup); } pub fn compute_pass_insert_debug_marker( @@ -965,12 +1025,14 @@ impl Global { color: u32, ) { let bytes = label.as_bytes(); - pass.base.string_data.extend_from_slice(bytes); + let base = pass.base_mut(); + base.string_data.extend_from_slice(bytes); - pass.base.commands.push(ComputeCommand::InsertDebugMarker { - color, - len: bytes.len(), - }); + base.commands + .push(ArcComputeCommand::::InsertDebugMarker { + color, + len: bytes.len(), + }); } pub fn compute_pass_write_timestamp( @@ -978,11 +1040,26 @@ impl Global { pass: &mut ComputePass, query_set_id: id::QuerySetId, query_index: u32, - ) { - pass.base.commands.push(ComputeCommand::WriteTimestamp { - query_set_id, - query_index, - }); + ) -> Result<(), ComputePassError> { + let hub = A::hub(self); + let query_set = hub + .query_sets + .read() + .get(query_set_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::WriteTimestamp, + inner: ComputePassErrorInner::InvalidQuerySet(query_set_id), + })? 
+ .clone(); + + pass.base_mut() + .commands + .push(ArcComputeCommand::WriteTimestamp { + query_set, + query_index, + }); + + Ok(()) } pub fn compute_pass_begin_pipeline_statistics_query( @@ -990,18 +1067,31 @@ impl Global { pass: &mut ComputePass, query_set_id: id::QuerySetId, query_index: u32, - ) { - pass.base + ) -> Result<(), ComputePassError> { + let hub = A::hub(self); + let query_set = hub + .query_sets + .read() + .get(query_set_id) + .map_err(|_| ComputePassError { + scope: PassErrorScope::WriteTimestamp, + inner: ComputePassErrorInner::InvalidQuerySet(query_set_id), + })? + .clone(); + + pass.base_mut() .commands - .push(ComputeCommand::BeginPipelineStatisticsQuery { - query_set_id, + .push(ArcComputeCommand::BeginPipelineStatisticsQuery { + query_set, query_index, }); + + Ok(()) } pub fn compute_pass_end_pipeline_statistics_query(&self, pass: &mut ComputePass) { - pass.base + pass.base_mut() .commands - .push(ComputeCommand::EndPipelineStatisticsQuery); + .push(ArcComputeCommand::::EndPipelineStatisticsQuery); } } diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 1253520dc7..537a523611 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -24,9 +24,10 @@ use std::{ }; use wgc::{ command::{bundle_ffi::*, render_commands::*}, - device::DeviceLostClosure, - id::{CommandEncoderId, TextureViewId}, + gfx_select, + id::CommandEncoderId, }; +use wgc::{device::DeviceLostClosure, id::TextureViewId}; use wgt::WasmNotSendSync; const LABEL: &str = "label"; @@ -469,6 +470,12 @@ impl Queue { } } +#[derive(Debug)] +pub struct ComputePass { + pass: wgc::command::ComputePass, + error_sink: ErrorSink, +} + #[derive(Debug)] pub struct CommandEncoder { error_sink: ErrorSink, @@ -507,7 +514,7 @@ impl crate::Context for ContextWgpuCore { type CommandEncoderId = wgc::id::CommandEncoderId; type CommandEncoderData = CommandEncoder; type ComputePassId = Unused; - type ComputePassData = wgc::command::ComputePass; + type 
ComputePassData = ComputePass; type RenderPassId = Unused; type RenderPassData = wgc::command::RenderPass; type CommandBufferId = wgc::id::CommandBufferId; @@ -1816,7 +1823,7 @@ impl crate::Context for ContextWgpuCore { fn command_encoder_begin_compute_pass( &self, encoder: &Self::CommandEncoderId, - _encoder_data: &Self::CommandEncoderData, + encoder_data: &Self::CommandEncoderData, desc: &ComputePassDescriptor<'_>, ) -> (Self::ComputePassId, Self::ComputePassData) { let timestamp_writes = @@ -1827,15 +1834,19 @@ impl crate::Context for ContextWgpuCore { beginning_of_pass_write_index: tw.beginning_of_pass_write_index, end_of_pass_write_index: tw.end_of_pass_write_index, }); + ( Unused, - wgc::command::ComputePass::new( - *encoder, - &wgc::command::ComputePassDescriptor { - label: desc.label.map(Borrowed), - timestamp_writes: timestamp_writes.as_ref(), - }, - ), + Self::ComputePassData { + pass: gfx_select!(encoder => self.0.command_encoder_create_compute_pass( + *encoder, + &wgc::command::ComputePassDescriptor { + label: desc.label.map(Borrowed), + timestamp_writes: timestamp_writes.as_ref(), + }, + )), + error_sink: encoder_data.error_sink.clone(), + }, ) } @@ -1847,7 +1858,7 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, ) { if let Err(cause) = wgc::gfx_select!( - encoder => self.0.command_encoder_run_compute_pass(*encoder, pass_data) + encoder => self.0.command_encoder_run_compute_pass(*encoder, &pass_data.pass) ) { let name = wgc::gfx_select!(encoder => self.0.command_buffer_label(encoder.into_command_buffer_id())); self.handle_error( @@ -2311,8 +2322,8 @@ impl crate::Context for ContextWgpuCore { pipeline: &Self::ComputePipelineId, _pipeline_data: &Self::ComputePipelineData, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_set_pipeline(pass_data, *pipeline)) + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_set_pipeline(&mut 
pass_data.pass, *pipeline)); } fn compute_pass_set_bind_group( @@ -2324,8 +2335,8 @@ impl crate::Context for ContextWgpuCore { _bind_group_data: &Self::BindGroupData, offsets: &[wgt::DynamicOffset], ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_set_bind_group(pass_data, index, *bind_group, offsets)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_set_bind_group(&mut pass_data.pass, index, *bind_group, offsets)); } fn compute_pass_set_push_constants( @@ -2335,8 +2346,8 @@ impl crate::Context for ContextWgpuCore { offset: u32, data: &[u8], ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_set_push_constant(pass_data, offset, data)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_set_push_constant(&mut pass_data.pass, offset, data)); } fn compute_pass_insert_debug_marker( @@ -2345,8 +2356,8 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, label: &str, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_insert_debug_marker(pass_data, label, 0)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_insert_debug_marker(&mut pass_data.pass, label, 0)); } fn compute_pass_push_debug_group( @@ -2355,8 +2366,8 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, group_label: &str, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_push_debug_group(pass_data, group_label, 0)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_push_debug_group(&mut pass_data.pass, group_label, 0)); } fn compute_pass_pop_debug_group( @@ -2364,8 +2375,8 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - let encoder 
= pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_pop_debug_group(pass_data)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_pop_debug_group(&mut pass_data.pass)); } fn compute_pass_write_timestamp( @@ -2376,8 +2387,8 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_write_timestamp(pass_data, *query_set, query_index)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_write_timestamp(&mut pass_data.pass, *query_set, query_index)); } fn compute_pass_begin_pipeline_statistics_query( @@ -2388,8 +2399,8 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_begin_pipeline_statistics_query(pass_data, *query_set, query_index)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_begin_pipeline_statistics_query(&mut pass_data.pass, *query_set, query_index)); } fn compute_pass_end_pipeline_statistics_query( @@ -2397,8 +2408,8 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_end_pipeline_statistics_query(pass_data)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_end_pipeline_statistics_query(&mut pass_data.pass)); } fn compute_pass_dispatch_workgroups( @@ -2409,8 +2420,8 @@ impl crate::Context for ContextWgpuCore { y: u32, z: u32, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups(pass_data, x, y, z)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder 
=> self.0.compute_pass_dispatch_workgroups(&mut pass_data.pass, x, y, z)); } fn compute_pass_dispatch_workgroups_indirect( @@ -2421,8 +2432,8 @@ impl crate::Context for ContextWgpuCore { _indirect_buffer_data: &Self::BufferData, indirect_offset: wgt::BufferAddress, ) { - let encoder = pass_data.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups_indirect(pass_data, *indirect_buffer, indirect_offset)); + let encoder = pass_data.pass.parent_id(); + wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups_indirect(&mut pass_data.pass, *indirect_buffer, indirect_offset)); } fn render_bundle_encoder_set_pipeline( From 26736285d7401934a5fdbb6f15391b6f01d57367 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Fri, 19 Apr 2024 23:39:31 +0200 Subject: [PATCH 05/11] simplify compute pass state flush, compute pass execution no longer needs to lock bind_group storage --- wgpu-core/src/command/bind.rs | 5 ++--- wgpu-core/src/command/compute.rs | 28 +++++++--------------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/wgpu-core/src/command/bind.rs b/wgpu-core/src/command/bind.rs index 7b2ac54552..c643611a96 100644 --- a/wgpu-core/src/command/bind.rs +++ b/wgpu-core/src/command/bind.rs @@ -4,7 +4,6 @@ use crate::{ binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout}, device::SHADER_STAGE_COUNT, hal_api::HalApi, - id::BindGroupId, pipeline::LateSizedBufferGroup, resource::Resource, }; @@ -359,11 +358,11 @@ impl Binder { &self.payloads[bind_range] } - pub(super) fn list_active(&self) -> impl Iterator + '_ { + pub(super) fn list_active<'a>(&'a self) -> impl Iterator>> + '_ { let payloads = &self.payloads; self.manager .list_active() - .map(move |index| payloads[index].group.as_ref().unwrap().as_info().id()) + .map(move |index| payloads[index].group.as_ref().unwrap()) } pub(super) fn invalid_mask(&self) -> BindGroupMask { diff --git a/wgpu-core/src/command/compute.rs 
b/wgpu-core/src/command/compute.rs index d81958bf3c..6242346ffb 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -1,7 +1,5 @@ use crate::{ - binding_model::{ - BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError, - }, + binding_model::{BindError, LateMinBufferBindingSizeMismatch, PushConstantUploadError}, command::{ bind::Binder, compute_command::{ArcComputeCommand, ComputeCommand}, @@ -19,7 +17,6 @@ use crate::{ init_tracker::MemoryInitKind, resource::{self, Resource}, snatch::SnatchGuard, - storage::Storage, track::{Tracker, TrackerIndex, UsageConflict, UsageScope}, validation::{check_buffer_usage, MissingBufferUsageError}, Label, @@ -250,22 +247,19 @@ impl<'a, A: HalApi> State<'a, A> { &mut self, raw_encoder: &mut A::CommandEncoder, base_trackers: &mut Tracker, - bind_group_guard: &Storage>, indirect_buffer: Option, snatch_guard: &SnatchGuard, ) -> Result<(), UsageConflict> { - for id in self.binder.list_active() { - unsafe { self.scope.merge_bind_group(&bind_group_guard[id].used)? }; + for bind_group in self.binder.list_active() { + unsafe { self.scope.merge_bind_group(&bind_group.used)? }; // Note: stateless trackers are not merged: the lifetime reference // is held to the bind group itself. 
} - for id in self.binder.list_active() { + for bind_group in self.binder.list_active() { unsafe { - base_trackers.set_and_remove_from_usage_scope_sparse( - &mut self.scope, - &bind_group_guard[id].used, - ) + base_trackers + .set_and_remove_from_usage_scope_sparse(&mut self.scope, &bind_group.used) } } @@ -381,7 +375,6 @@ impl Global { *status = CommandEncoderStatus::Error; let raw = encoder.open().map_pass_err(pass_scope)?; - let bind_group_guard = hub.bind_groups.read(); let query_set_guard = hub.query_sets.read(); let mut state = State { @@ -646,13 +639,7 @@ impl Global { state.is_ready().map_pass_err(scope)?; state - .flush_states( - raw, - &mut intermediate_trackers, - &*bind_group_guard, - None, - &snatch_guard, - ) + .flush_states(raw, &mut intermediate_trackers, None, &snatch_guard) .map_pass_err(scope)?; let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension; @@ -725,7 +712,6 @@ impl Global { .flush_states( raw, &mut intermediate_trackers, - &*bind_group_guard, Some(buffer.as_info().tracker_index()), &snatch_guard, ) From a88a9064fb426c582dd10d5183e14b709ac4b64a Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sat, 20 Apr 2024 00:03:56 +0200 Subject: [PATCH 06/11] wgpu sided error handling --- wgpu/src/backend/wgpu_core.rs | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 537a523611..67b512c51f 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -2323,7 +2323,10 @@ impl crate::Context for ContextWgpuCore { _pipeline_data: &Self::ComputePipelineData, ) { let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_set_pipeline(&mut pass_data.pass, *pipeline)); + if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_set_pipeline(&mut pass_data.pass, *pipeline)) + { + self.handle_error_nolabel(&pass_data.error_sink, cause, 
"ComputePass::set_pipeline"); + } } fn compute_pass_set_bind_group( @@ -2336,7 +2339,10 @@ impl crate::Context for ContextWgpuCore { offsets: &[wgt::DynamicOffset], ) { let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_set_bind_group(&mut pass_data.pass, index, *bind_group, offsets)); + if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_set_bind_group(&mut pass_data.pass, index, *bind_group, offsets)) + { + self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_bind_group"); + } } fn compute_pass_set_push_constants( @@ -2388,7 +2394,10 @@ impl crate::Context for ContextWgpuCore { query_index: u32, ) { let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_write_timestamp(&mut pass_data.pass, *query_set, query_index)); + if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_write_timestamp(&mut pass_data.pass, *query_set, query_index)) + { + self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::write_timestamp"); + } } fn compute_pass_begin_pipeline_statistics_query( @@ -2400,7 +2409,14 @@ impl crate::Context for ContextWgpuCore { query_index: u32, ) { let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_begin_pipeline_statistics_query(&mut pass_data.pass, *query_set, query_index)); + if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_begin_pipeline_statistics_query(&mut pass_data.pass, *query_set, query_index)) + { + self.handle_error_nolabel( + &pass_data.error_sink, + cause, + "ComputePass::begin_pipeline_statistics_query", + ); + } } fn compute_pass_end_pipeline_statistics_query( @@ -2433,7 +2449,14 @@ impl crate::Context for ContextWgpuCore { indirect_offset: wgt::BufferAddress, ) { let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups_indirect(&mut pass_data.pass, *indirect_buffer, indirect_offset)); + if 
let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups_indirect(&mut pass_data.pass, *indirect_buffer, indirect_offset)) + { + self.handle_error_nolabel( + &pass_data.error_sink, + cause, + "ComputePass::dispatch_workgroups_indirect", + ); + } } fn render_bundle_encoder_set_pipeline( From 841cfcb0ea77ad80bfd3b74e53b9f5a51ccde9aa Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sat, 20 Apr 2024 12:18:49 +0200 Subject: [PATCH 07/11] make ComputePass hal dependent, removing command cast hack. Introduce DynComputePass on wgpu side --- wgpu-core/src/command/compute.rs | 139 +++++++++------------ wgpu-core/src/lib.rs | 11 ++ wgpu/src/backend/wgpu_core.rs | 205 ++++++++++++++++++++++++++----- 3 files changed, 239 insertions(+), 116 deletions(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 6242346ffb..6ed2e5675b 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -34,12 +34,8 @@ use wgt::{BufferAddress, DynamicOffset}; use std::sync::Arc; use std::{fmt, mem, str}; -pub struct ComputePass { - // TODO(#5124) / workaround for generic proliferation: - // We want to store `BasePass>` here, but this would mean - // that `ComputePass` becomes generic over HalApi meaning it would need to be identified by - // an identifier as well, causing another array on the hub and another indirection on every access. 
- base: Box, +pub struct ComputePass { + base: BasePass>, parent_id: id::CommandEncoderId, timestamp_writes: Option, @@ -48,10 +44,10 @@ pub struct ComputePass { current_pipeline: StateChange, } -impl ComputePass { - fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self { +impl ComputePass { + fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self { Self { - base: Box::new(BasePass::>::new(&desc.label)), + base: BasePass::>::new(&desc.label), parent_id, timestamp_writes: desc.timestamp_writes.cloned(), @@ -63,23 +59,11 @@ impl ComputePass { pub fn parent_id(&self) -> id::CommandEncoderId { self.parent_id } - - fn base(&self) -> &BasePass> { - self.base - .downcast_ref() - .expect("Downcast failed, unexpected backend") - } - - fn base_mut(&mut self) -> &mut BasePass> { - self.base - .downcast_mut() - .expect("Downcast failed, unexpected backend") - } } -impl fmt::Debug for ComputePass { +impl fmt::Debug for ComputePass { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "ComputePass {{ encoder_id: {:?} }}", self.parent_id,) + write!(f, "ComputePass {{ encoder_id: {:?} }}", self.parent_id) } } @@ -284,18 +268,17 @@ impl Global { &self, parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor, - ) -> ComputePass { - ComputePass::new::(parent_id, desc) + ) -> ComputePass { + ComputePass::new(parent_id, desc) } pub fn command_encoder_run_compute_pass( &self, - encoder_id: id::CommandEncoderId, - pass: &ComputePass, + pass: &ComputePass, ) -> Result<(), ComputePassError> { - self.command_encoder_run_compute_pass_impl::( - encoder_id, - pass.base().as_ref(), + self.command_encoder_run_compute_pass_impl( + pass.parent_id, + pass.base.as_ref(), pass.timestamp_writes.as_ref(), ) } @@ -839,21 +822,15 @@ impl Global { impl Global { pub fn compute_pass_set_bind_group( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, index: u32, bind_group_id: id::BindGroupId, offsets: &[DynamicOffset], ) -> 
Result<(), ComputePassError> { - let base: &mut BasePass> = pass - .base - .downcast_mut() - .expect("Downcast failed, unexpected backend"); - //pass.base_mut::(); // borrow checker not happy with using this util. - let redundant = pass.current_bind_groups.set_and_check_redundant( bind_group_id, index, - &mut base.dynamic_offsets, + &mut pass.base.dynamic_offsets, offsets, ); @@ -872,7 +849,7 @@ impl Global { })? .clone(); - base.commands.push(ArcComputeCommand::SetBindGroup { + pass.base.commands.push(ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets: offsets.len(), bind_group, @@ -883,7 +860,7 @@ impl Global { pub fn compute_pass_set_pipeline( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, pipeline_id: id::ComputePipelineId, ) -> Result<(), ComputePassError> { if pass.current_pipeline.set_and_check_redundant(pipeline_id) { @@ -901,7 +878,7 @@ impl Global { })? .clone(); - pass.base_mut() + pass.base .commands .push(ArcComputeCommand::SetPipeline(pipeline)); @@ -910,7 +887,7 @@ impl Global { pub fn compute_pass_set_push_constant( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, offset: u32, data: &[u8], ) { @@ -924,40 +901,39 @@ impl Global { 0, "Push constant size must be aligned to 4 bytes." ); - let base = pass.base_mut(); - let value_offset = base.push_constant_data.len().try_into().expect( + let value_offset = pass.base.push_constant_data.len().try_into().expect( "Ran out of push constant space. 
Don't set 4gb of push constants per ComputePass.", ); // TODO: make this an error that can be handled - base.push_constant_data.extend( + pass.base.push_constant_data.extend( data.chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize) .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])), ); - base.commands.push(ArcComputeCommand::::SetPushConstant { - offset, - size_bytes: data.len() as u32, - values_offset: value_offset, - }); + pass.base + .commands + .push(ArcComputeCommand::::SetPushConstant { + offset, + size_bytes: data.len() as u32, + values_offset: value_offset, + }); } pub fn compute_pass_dispatch_workgroups( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, groups_x: u32, groups_y: u32, groups_z: u32, ) { - pass.base_mut() - .commands - .push(ArcComputeCommand::::Dispatch([ - groups_x, groups_y, groups_z, - ])); + pass.base.commands.push(ArcComputeCommand::::Dispatch([ + groups_x, groups_y, groups_z, + ])); } pub fn compute_pass_dispatch_workgroups_indirect( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, buffer_id: id::BufferId, offset: BufferAddress, ) -> Result<(), ComputePassError> { @@ -975,7 +951,7 @@ impl Global { })? 
.clone(); - pass.base_mut() + pass.base .commands .push(ArcComputeCommand::::DispatchIndirect { buffer, offset }); @@ -984,37 +960,38 @@ impl Global { pub fn compute_pass_push_debug_group( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, label: &str, color: u32, ) { let bytes = label.as_bytes(); - let base = pass.base_mut(); - base.string_data.extend_from_slice(bytes); + pass.base.string_data.extend_from_slice(bytes); - base.commands.push(ArcComputeCommand::::PushDebugGroup { - color, - len: bytes.len(), - }); + pass.base + .commands + .push(ArcComputeCommand::::PushDebugGroup { + color, + len: bytes.len(), + }); } - pub fn compute_pass_pop_debug_group(&self, pass: &mut ComputePass) { - pass.base_mut::() + pub fn compute_pass_pop_debug_group(&self, pass: &mut ComputePass) { + pass.base .commands .push(ArcComputeCommand::::PopDebugGroup); } pub fn compute_pass_insert_debug_marker( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, label: &str, color: u32, ) { let bytes = label.as_bytes(); - let base = pass.base_mut(); - base.string_data.extend_from_slice(bytes); + pass.base.string_data.extend_from_slice(bytes); - base.commands + pass.base + .commands .push(ArcComputeCommand::::InsertDebugMarker { color, len: bytes.len(), @@ -1023,7 +1000,7 @@ impl Global { pub fn compute_pass_write_timestamp( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, query_set_id: id::QuerySetId, query_index: u32, ) -> Result<(), ComputePassError> { @@ -1038,19 +1015,17 @@ impl Global { })? .clone(); - pass.base_mut() - .commands - .push(ArcComputeCommand::WriteTimestamp { - query_set, - query_index, - }); + pass.base.commands.push(ArcComputeCommand::WriteTimestamp { + query_set, + query_index, + }); Ok(()) } pub fn compute_pass_begin_pipeline_statistics_query( &self, - pass: &mut ComputePass, + pass: &mut ComputePass, query_set_id: id::QuerySetId, query_index: u32, ) -> Result<(), ComputePassError> { @@ -1065,7 +1040,7 @@ impl Global { })? 
.clone(); - pass.base_mut() + pass.base .commands .push(ArcComputeCommand::BeginPipelineStatisticsQuery { query_set, @@ -1075,8 +1050,8 @@ impl Global { Ok(()) } - pub fn compute_pass_end_pipeline_statistics_query(&self, pass: &mut ComputePass) { - pass.base_mut() + pub fn compute_pass_end_pipeline_statistics_query(&self, pass: &mut ComputePass) { + pass.base .commands .push(ArcComputeCommand::::EndPipelineStatisticsQuery); } diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 032d85a4bc..c09bb7119b 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -302,6 +302,17 @@ macro_rules! gfx_select { other => panic!("Unexpected backend {:?}", other), } }; + + ($id:expr => $method:ident $params:tt) => { + match $id.backend() { + wgt::Backend::Vulkan => $crate::gfx_if_vulkan!($method::<$crate::api::Vulkan> $params), + wgt::Backend::Metal => $crate::gfx_if_metal!($method::<$crate::api::Metal> $params), + wgt::Backend::Dx12 => $crate::gfx_if_dx12!($method::<$crate::api::Dx12> $params), + wgt::Backend::Gl => $crate::gfx_if_gles!($method::<$crate::api::Gles> $params), + wgt::Backend::Empty => $crate::gfx_if_empty!($method::<$crate::api::Empty> $params), + other => panic!("Unexpected backend {:?}", other), + } + }; } #[cfg(feature = "api_log_info")] diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 67b512c51f..55469ba859 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -25,6 +25,7 @@ use std::{ use wgc::{ command::{bundle_ffi::*, render_commands::*}, gfx_select, + hal_api::HalApi, id::CommandEncoderId, }; use wgc::{device::DeviceLostClosure, id::TextureViewId}; @@ -472,10 +473,141 @@ impl Queue { #[derive(Debug)] pub struct ComputePass { - pass: wgc::command::ComputePass, + pass: Box, error_sink: ErrorSink, } +/// Trait for type erasing ComputePass. +// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent. 
+// Practically speaking this allows us merge gfx_select with type erasure: +// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select. +trait DynComputePass: std::fmt::Debug + WasmNotSendSync { + fn run(&mut self, context: &wgc::global::Global) -> Result<(), wgc::command::ComputePassError>; + fn set_bind_group( + &mut self, + context: &wgc::global::Global, + index: u32, + bind_group_id: wgc::id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), wgc::command::ComputePassError>; + fn set_pipeline( + &mut self, + context: &wgc::global::Global, + pipeline_id: wgc::id::ComputePipelineId, + ) -> Result<(), wgc::command::ComputePassError>; + fn set_push_constant(&mut self, context: &wgc::global::Global, offset: u32, data: &[u8]); + fn dispatch_workgroups( + &mut self, + context: &wgc::global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ); + fn dispatch_workgroups_indirect( + &mut self, + context: &wgc::global::Global, + buffer_id: wgc::id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), wgc::command::ComputePassError>; + fn push_debug_group(&mut self, context: &wgc::global::Global, label: &str, color: u32); + fn pop_debug_group(&mut self, context: &wgc::global::Global); + fn insert_debug_marker(&mut self, context: &wgc::global::Global, label: &str, color: u32); + fn write_timestamp( + &mut self, + context: &wgc::global::Global, + query_set_id: wgc::id::QuerySetId, + query_index: u32, + ) -> Result<(), wgc::command::ComputePassError>; + fn begin_pipeline_statistics_query( + &mut self, + context: &wgc::global::Global, + query_set_id: wgc::id::QuerySetId, + query_index: u32, + ) -> Result<(), wgc::command::ComputePassError>; + fn end_pipeline_statistics_query(&mut self, context: &wgc::global::Global); +} + +impl DynComputePass for wgc::command::ComputePass { + fn run(&mut self, context: &wgc::global::Global) -> Result<(), wgc::command::ComputePassError> { + 
context.command_encoder_run_compute_pass(self) + } + + fn set_bind_group( + &mut self, + context: &wgc::global::Global, + index: u32, + bind_group_id: wgc::id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), wgc::command::ComputePassError> { + context.compute_pass_set_bind_group(self, index, bind_group_id, offsets) + } + + fn set_pipeline( + &mut self, + context: &wgc::global::Global, + pipeline_id: wgc::id::ComputePipelineId, + ) -> Result<(), wgc::command::ComputePassError> { + context.compute_pass_set_pipeline(self, pipeline_id) + } + + fn set_push_constant(&mut self, context: &wgc::global::Global, offset: u32, data: &[u8]) { + context.compute_pass_set_push_constant(self, offset, data) + } + + fn dispatch_workgroups( + &mut self, + context: &wgc::global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ) { + context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z) + } + + fn dispatch_workgroups_indirect( + &mut self, + context: &wgc::global::Global, + buffer_id: wgc::id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), wgc::command::ComputePassError> { + context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset) + } + + fn push_debug_group(&mut self, context: &wgc::global::Global, label: &str, color: u32) { + context.compute_pass_push_debug_group(self, label, color) + } + + fn pop_debug_group(&mut self, context: &wgc::global::Global) { + context.compute_pass_pop_debug_group(self) + } + + fn insert_debug_marker(&mut self, context: &wgc::global::Global, label: &str, color: u32) { + context.compute_pass_insert_debug_marker(self, label, color) + } + + fn write_timestamp( + &mut self, + context: &wgc::global::Global, + query_set_id: wgc::id::QuerySetId, + query_index: u32, + ) -> Result<(), wgc::command::ComputePassError> { + context.compute_pass_write_timestamp(self, query_set_id, query_index) + } + + fn begin_pipeline_statistics_query( + &mut self, + context: &wgc::global::Global, + 
query_set_id: wgc::id::QuerySetId, + query_index: u32, + ) -> Result<(), wgc::command::ComputePassError> { + context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index) + } + + fn end_pipeline_statistics_query(&mut self, context: &wgc::global::Global) { + context.compute_pass_end_pipeline_statistics_query(self) + } +} + #[derive(Debug)] pub struct CommandEncoder { error_sink: ErrorSink, @@ -1835,16 +1967,25 @@ impl crate::Context for ContextWgpuCore { end_of_pass_write_index: tw.end_of_pass_write_index, }); + fn create_dyn_compute_pass( + context: &ContextWgpuCore, + encoder: wgc::id::CommandEncoderId, + desc: &wgc::command::ComputePassDescriptor<'_>, + ) -> Box { + Box::new( + context + .0 + .command_encoder_create_compute_pass::(encoder, desc), + ) + } + ( Unused, Self::ComputePassData { - pass: gfx_select!(encoder => self.0.command_encoder_create_compute_pass( - *encoder, - &wgc::command::ComputePassDescriptor { - label: desc.label.map(Borrowed), - timestamp_writes: timestamp_writes.as_ref(), - }, - )), + pass: gfx_select!(encoder => create_dyn_compute_pass(self, *encoder, &wgc::command::ComputePassDescriptor { + label: desc.label.map(Borrowed), + timestamp_writes: timestamp_writes.as_ref(), + })), error_sink: encoder_data.error_sink.clone(), }, ) @@ -1857,9 +1998,7 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - if let Err(cause) = wgc::gfx_select!( - encoder => self.0.command_encoder_run_compute_pass(*encoder, &pass_data.pass) - ) { + if let Err(cause) = pass_data.pass.run(&self.0) { let name = wgc::gfx_select!(encoder => self.0.command_buffer_label(encoder.into_command_buffer_id())); self.handle_error( &encoder_data.error_sink, @@ -2322,9 +2461,7 @@ impl crate::Context for ContextWgpuCore { pipeline: &Self::ComputePipelineId, _pipeline_data: &Self::ComputePipelineData, ) { - let encoder = pass_data.pass.parent_id(); - if let Err(cause) = 
wgc::gfx_select!(encoder => self.0.compute_pass_set_pipeline(&mut pass_data.pass, *pipeline)) - { + if let Err(cause) = pass_data.pass.set_pipeline(&self.0, *pipeline) { self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_pipeline"); } } @@ -2338,8 +2475,9 @@ impl crate::Context for ContextWgpuCore { _bind_group_data: &Self::BindGroupData, offsets: &[wgt::DynamicOffset], ) { - let encoder = pass_data.pass.parent_id(); - if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_set_bind_group(&mut pass_data.pass, index, *bind_group, offsets)) + if let Err(cause) = pass_data + .pass + .set_bind_group(&self.0, index, *bind_group, offsets) { self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::set_bind_group"); } @@ -2352,8 +2490,7 @@ impl crate::Context for ContextWgpuCore { offset: u32, data: &[u8], ) { - let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_set_push_constant(&mut pass_data.pass, offset, data)); + pass_data.pass.set_push_constant(&self.0, offset, data); } fn compute_pass_insert_debug_marker( @@ -2362,8 +2499,7 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, label: &str, ) { - let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_insert_debug_marker(&mut pass_data.pass, label, 0)); + pass_data.pass.insert_debug_marker(&self.0, label, 0); } fn compute_pass_push_debug_group( @@ -2372,8 +2508,7 @@ impl crate::Context for ContextWgpuCore { pass_data: &mut Self::ComputePassData, group_label: &str, ) { - let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_push_debug_group(&mut pass_data.pass, group_label, 0)); + pass_data.pass.push_debug_group(&self.0, group_label, 0); } fn compute_pass_pop_debug_group( @@ -2381,8 +2516,7 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - let encoder = 
pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_pop_debug_group(&mut pass_data.pass)); + pass_data.pass.pop_debug_group(&self.0); } fn compute_pass_write_timestamp( @@ -2393,8 +2527,9 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - let encoder = pass_data.pass.parent_id(); - if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_write_timestamp(&mut pass_data.pass, *query_set, query_index)) + if let Err(cause) = pass_data + .pass + .write_timestamp(&self.0, *query_set, query_index) { self.handle_error_nolabel(&pass_data.error_sink, cause, "ComputePass::write_timestamp"); } @@ -2408,8 +2543,10 @@ impl crate::Context for ContextWgpuCore { _query_set_data: &Self::QuerySetData, query_index: u32, ) { - let encoder = pass_data.pass.parent_id(); - if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_begin_pipeline_statistics_query(&mut pass_data.pass, *query_set, query_index)) + if let Err(cause) = + pass_data + .pass + .begin_pipeline_statistics_query(&self.0, *query_set, query_index) { self.handle_error_nolabel( &pass_data.error_sink, @@ -2424,8 +2561,7 @@ impl crate::Context for ContextWgpuCore { _pass: &mut Self::ComputePassId, pass_data: &mut Self::ComputePassData, ) { - let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_end_pipeline_statistics_query(&mut pass_data.pass)); + pass_data.pass.end_pipeline_statistics_query(&self.0); } fn compute_pass_dispatch_workgroups( @@ -2436,8 +2572,7 @@ impl crate::Context for ContextWgpuCore { y: u32, z: u32, ) { - let encoder = pass_data.pass.parent_id(); - wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups(&mut pass_data.pass, x, y, z)); + pass_data.pass.dispatch_workgroups(&self.0, x, y, z); } fn compute_pass_dispatch_workgroups_indirect( @@ -2448,8 +2583,10 @@ impl crate::Context for ContextWgpuCore { _indirect_buffer_data: &Self::BufferData, 
indirect_offset: wgt::BufferAddress, ) { - let encoder = pass_data.pass.parent_id(); - if let Err(cause) = wgc::gfx_select!(encoder => self.0.compute_pass_dispatch_workgroups_indirect(&mut pass_data.pass, *indirect_buffer, indirect_offset)) + if let Err(cause) = + pass_data + .pass + .dispatch_workgroups_indirect(&self.0, *indirect_buffer, indirect_offset) { self.handle_error_nolabel( &pass_data.error_sink, From 4f3f1db8799d2d7b910571ef21744564af39d7e5 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sat, 20 Apr 2024 12:19:04 +0200 Subject: [PATCH 08/11] remove stray repr(C) --- wgpu-core/src/command/compute.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 6ed2e5675b..154be34de7 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -68,7 +68,6 @@ impl fmt::Debug for ComputePass { } /// Describes the writing of timestamp values in a compute pass. -#[repr(C)] #[derive(Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct ComputePassTimestampWrites { From 5bfef4bfa08f7b520abb8319faf9254e5a0529c6 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sat, 20 Apr 2024 15:56:53 +0200 Subject: [PATCH 09/11] changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b9fdf7783..f65f1cf935 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,7 @@ Bottom level categories: - Added support for pipeline-overridable constants. By @teoxoy & @jimblandy in [#5500](https://github.com/gfx-rs/wgpu/pull/5500) - Add `SUBGROUP, SUBGROUP_VERTEX, SUBGROUP_BARRIER` features. By @exrook and @lichtso in [#5301](https://github.com/gfx-rs/wgpu/pull/5301) - Support disabling zero-initialization of workgroup local memory in compute shaders. By @DJMcNab in [#5508](https://github.com/gfx-rs/wgpu/pull/5508) +- `wgpu::ComputePass` recording methods (e.g. 
`wgpu::ComputePass::set_pipeline`) no longer impose a lifetime constraint on passed-in resources. By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569) #### GLES From 0d470a92cf6a99746f37c5940001b08715139b12 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sat, 20 Apr 2024 16:33:17 +0200 Subject: [PATCH 10/11] fix deno issues -> move DynComputePass into wgc --- deno_webgpu/command_encoder.rs | 5 +- deno_webgpu/compute_pass.rs | 67 ++++---- .../tests/compute_pass_resource_ownership.rs | 4 + wgpu-core/src/command/compute.rs | 8 + wgpu-core/src/command/dyn_compute_pass.rs | 136 ++++++++++++++++ wgpu-core/src/command/mod.rs | 5 +- wgpu-core/src/lib.rs | 11 -- wgpu/src/backend/wgpu_core.rs | 148 +----------------- 8 files changed, 185 insertions(+), 199 deletions(-) create mode 100644 wgpu-core/src/command/dyn_compute_pass.rs diff --git a/deno_webgpu/command_encoder.rs b/deno_webgpu/command_encoder.rs index 20dfe0db09..b82fba92ea 100644 --- a/deno_webgpu/command_encoder.rs +++ b/deno_webgpu/command_encoder.rs @@ -254,13 +254,14 @@ pub fn op_webgpu_command_encoder_begin_compute_pass( None }; + let instance = state.borrow::(); + let command_encoder = &command_encoder_resource.1; let descriptor = wgpu_core::command::ComputePassDescriptor { label: Some(label), timestamp_writes: timestamp_writes.as_ref(), }; - let compute_pass = - wgpu_core::command::ComputePass::new(command_encoder_resource.1, &descriptor); + let compute_pass = gfx_select!(command_encoder => instance.command_encoder_create_compute_pass_dyn(*command_encoder, &descriptor)); let rid = state .resource_table diff --git a/deno_webgpu/compute_pass.rs b/deno_webgpu/compute_pass.rs index 2cdea2c8f2..fb499e7e06 100644 --- a/deno_webgpu/compute_pass.rs +++ b/deno_webgpu/compute_pass.rs @@ -10,7 +10,9 @@ use std::cell::RefCell; use super::error::WebGpuResult; -pub(crate) struct WebGpuComputePass(pub(crate) RefCell); +pub(crate) struct WebGpuComputePass( + pub(crate) RefCell>, +); impl Resource for 
WebGpuComputePass { fn name(&self) -> Cow { "webGPUComputePass".into() @@ -31,10 +33,10 @@ pub fn op_webgpu_compute_pass_set_pipeline( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_set_pipeline( - &mut compute_pass_resource.0.borrow_mut(), - compute_pipeline_resource.1, - ); + compute_pass_resource + .0 + .borrow_mut() + .set_pipeline(state.borrow(), compute_pipeline_resource.1)?; Ok(WebGpuResult::empty()) } @@ -52,12 +54,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups( - &mut compute_pass_resource.0.borrow_mut(), - x, - y, - z, - ); + compute_pass_resource + .0 + .borrow_mut() + .dispatch_workgroups(state.borrow(), x, y, z); Ok(WebGpuResult::empty()) } @@ -77,11 +77,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups_indirect( - &mut compute_pass_resource.0.borrow_mut(), - buffer_resource.1, - indirect_offset, - ); + compute_pass_resource + .0 + .borrow_mut() + .dispatch_workgroups_indirect(state.borrow(), buffer_resource.1, indirect_offset)?; Ok(WebGpuResult::empty()) } @@ -90,24 +89,15 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect( #[serde] pub fn op_webgpu_compute_pass_end( state: &mut OpState, - #[smi] command_encoder_rid: ResourceId, #[smi] compute_pass_rid: ResourceId, ) -> Result { - let command_encoder_resource = - state - .resource_table - .get::(command_encoder_rid)?; - let command_encoder = command_encoder_resource.1; let compute_pass_resource = state .resource_table .take::(compute_pass_rid)?; - let compute_pass = &compute_pass_resource.0.borrow(); - let instance = state.borrow::(); - gfx_ok!(command_encoder => instance.command_encoder_run_compute_pass( - command_encoder, - compute_pass - )) + 
compute_pass_resource.0.borrow_mut().run(state.borrow())?; + + Ok(WebGpuResult::empty()) } #[op2] @@ -137,12 +127,12 @@ pub fn op_webgpu_compute_pass_set_bind_group( let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len]; - wgpu_core::command::compute_commands::wgpu_compute_pass_set_bind_group( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().set_bind_group( + state.borrow(), index, bind_group_resource.1, dynamic_offsets_data, - ); + )?; Ok(WebGpuResult::empty()) } @@ -158,8 +148,8 @@ pub fn op_webgpu_compute_pass_push_debug_group( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_push_debug_group( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().push_debug_group( + state.borrow(), group_label, 0, // wgpu#975 ); @@ -177,9 +167,10 @@ pub fn op_webgpu_compute_pass_pop_debug_group( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_pop_debug_group( - &mut compute_pass_resource.0.borrow_mut(), - ); + compute_pass_resource + .0 + .borrow_mut() + .pop_debug_group(state.borrow()); Ok(WebGpuResult::empty()) } @@ -195,8 +186,8 @@ pub fn op_webgpu_compute_pass_insert_debug_marker( .resource_table .get::(compute_pass_rid)?; - wgpu_core::command::compute_commands::wgpu_compute_pass_insert_debug_marker( - &mut compute_pass_resource.0.borrow_mut(), + compute_pass_resource.0.borrow_mut().insert_debug_marker( + state.borrow(), marker_label, 0, // wgpu#975 ); diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs index 863c03b304..8a07b76201 100644 --- a/tests/tests/compute_pass_resource_ownership.rs +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -115,6 +115,10 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) { cpass.set_bind_group(0, &bind_group, &[]); cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); + // 
TODO: + // write_timestamp + // begin_pipeline_statistics_query + // Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what. drop(pipeline); drop(bind_group); diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 154be34de7..997c62e8b1 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -271,6 +271,14 @@ impl Global { ComputePass::new(parent_id, desc) } + pub fn command_encoder_create_compute_pass_dyn( + &self, + parent_id: id::CommandEncoderId, + desc: &ComputePassDescriptor, + ) -> Box { + Box::new(ComputePass::::new(parent_id, desc)) + } + pub fn command_encoder_run_compute_pass( &self, pass: &ComputePass, diff --git a/wgpu-core/src/command/dyn_compute_pass.rs b/wgpu-core/src/command/dyn_compute_pass.rs new file mode 100644 index 0000000000..b7ffea3d42 --- /dev/null +++ b/wgpu-core/src/command/dyn_compute_pass.rs @@ -0,0 +1,136 @@ +use wgt::WasmNotSendSync; + +use crate::{global, hal_api::HalApi, id}; + +use super::{ComputePass, ComputePassError}; + +/// Trait for type erasing ComputePass. +// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent. +// Practically speaking this allows us merge gfx_select with type erasure: +// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select. 
+pub trait DynComputePass: std::fmt::Debug + WasmNotSendSync { + fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError>; + fn set_bind_group( + &mut self, + context: &global::Global, + index: u32, + bind_group_id: id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), ComputePassError>; + fn set_pipeline( + &mut self, + context: &global::Global, + pipeline_id: id::ComputePipelineId, + ) -> Result<(), ComputePassError>; + fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]); + fn dispatch_workgroups( + &mut self, + context: &global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ); + fn dispatch_workgroups_indirect( + &mut self, + context: &global::Global, + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), ComputePassError>; + fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32); + fn pop_debug_group(&mut self, context: &global::Global); + fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32); + fn write_timestamp( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError>; + fn begin_pipeline_statistics_query( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError>; + fn end_pipeline_statistics_query(&mut self, context: &global::Global); +} + +impl DynComputePass for ComputePass { + fn run(&mut self, context: &global::Global) -> Result<(), ComputePassError> { + context.command_encoder_run_compute_pass(self) + } + + fn set_bind_group( + &mut self, + context: &global::Global, + index: u32, + bind_group_id: id::BindGroupId, + offsets: &[wgt::DynamicOffset], + ) -> Result<(), ComputePassError> { + context.compute_pass_set_bind_group(self, index, bind_group_id, offsets) + } + + fn set_pipeline( + &mut self, + context: &global::Global, + pipeline_id: 
id::ComputePipelineId, + ) -> Result<(), ComputePassError> { + context.compute_pass_set_pipeline(self, pipeline_id) + } + + fn set_push_constant(&mut self, context: &global::Global, offset: u32, data: &[u8]) { + context.compute_pass_set_push_constant(self, offset, data) + } + + fn dispatch_workgroups( + &mut self, + context: &global::Global, + groups_x: u32, + groups_y: u32, + groups_z: u32, + ) { + context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z) + } + + fn dispatch_workgroups_indirect( + &mut self, + context: &global::Global, + buffer_id: id::BufferId, + offset: wgt::BufferAddress, + ) -> Result<(), ComputePassError> { + context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset) + } + + fn push_debug_group(&mut self, context: &global::Global, label: &str, color: u32) { + context.compute_pass_push_debug_group(self, label, color) + } + + fn pop_debug_group(&mut self, context: &global::Global) { + context.compute_pass_pop_debug_group(self) + } + + fn insert_debug_marker(&mut self, context: &global::Global, label: &str, color: u32) { + context.compute_pass_insert_debug_marker(self, label, color) + } + + fn write_timestamp( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + context.compute_pass_write_timestamp(self, query_set_id, query_index) + } + + fn begin_pipeline_statistics_query( + &mut self, + context: &global::Global, + query_set_id: id::QuerySetId, + query_index: u32, + ) -> Result<(), ComputePassError> { + context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index) + } + + fn end_pipeline_statistics_query(&mut self, context: &global::Global) { + context.compute_pass_end_pipeline_statistics_query(self) + } +} diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index d53f47bf42..bbf9adc552 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -5,6 +5,7 @@ mod 
clear; mod compute; mod compute_command; mod draw; +mod dyn_compute_pass; mod memory_init; mod query; mod render; @@ -14,8 +15,8 @@ use std::sync::Arc; pub(crate) use self::clear::clear_texture; pub use self::{ - bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, query::*, - render::*, transfer::*, + bundle::*, clear::ClearError, compute::*, compute_command::ComputeCommand, draw::*, + dyn_compute_pass::DynComputePass, query::*, render::*, transfer::*, }; pub(crate) use allocator::CommandAllocator; diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index c09bb7119b..032d85a4bc 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -302,17 +302,6 @@ macro_rules! gfx_select { other => panic!("Unexpected backend {:?}", other), } }; - - ($id:expr => $method:ident $params:tt) => { - match $id.backend() { - wgt::Backend::Vulkan => $crate::gfx_if_vulkan!($method::<$crate::api::Vulkan> $params), - wgt::Backend::Metal => $crate::gfx_if_metal!($method::<$crate::api::Metal> $params), - wgt::Backend::Dx12 => $crate::gfx_if_dx12!($method::<$crate::api::Dx12> $params), - wgt::Backend::Gl => $crate::gfx_if_gles!($method::<$crate::api::Gles> $params), - wgt::Backend::Empty => $crate::gfx_if_empty!($method::<$crate::api::Empty> $params), - other => panic!("Unexpected backend {:?}", other), - } - }; } #[cfg(feature = "api_log_info")] diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 55469ba859..efedee66cd 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -25,7 +25,6 @@ use std::{ use wgc::{ command::{bundle_ffi::*, render_commands::*}, gfx_select, - hal_api::HalApi, id::CommandEncoderId, }; use wgc::{device::DeviceLostClosure, id::TextureViewId}; @@ -473,141 +472,10 @@ impl Queue { #[derive(Debug)] pub struct ComputePass { - pass: Box, + pass: Box, error_sink: ErrorSink, } -/// Trait for type erasing ComputePass. 
-// TODO(#5124): wgpu-core's ComputePass trait should not be hal type dependent. -// Practically speaking this allows us merge gfx_select with type erasure: -// The alternative would be to introduce ComputePassId which then first needs to be looked up and then dispatch via gfx_select. -trait DynComputePass: std::fmt::Debug + WasmNotSendSync { - fn run(&mut self, context: &wgc::global::Global) -> Result<(), wgc::command::ComputePassError>; - fn set_bind_group( - &mut self, - context: &wgc::global::Global, - index: u32, - bind_group_id: wgc::id::BindGroupId, - offsets: &[wgt::DynamicOffset], - ) -> Result<(), wgc::command::ComputePassError>; - fn set_pipeline( - &mut self, - context: &wgc::global::Global, - pipeline_id: wgc::id::ComputePipelineId, - ) -> Result<(), wgc::command::ComputePassError>; - fn set_push_constant(&mut self, context: &wgc::global::Global, offset: u32, data: &[u8]); - fn dispatch_workgroups( - &mut self, - context: &wgc::global::Global, - groups_x: u32, - groups_y: u32, - groups_z: u32, - ); - fn dispatch_workgroups_indirect( - &mut self, - context: &wgc::global::Global, - buffer_id: wgc::id::BufferId, - offset: wgt::BufferAddress, - ) -> Result<(), wgc::command::ComputePassError>; - fn push_debug_group(&mut self, context: &wgc::global::Global, label: &str, color: u32); - fn pop_debug_group(&mut self, context: &wgc::global::Global); - fn insert_debug_marker(&mut self, context: &wgc::global::Global, label: &str, color: u32); - fn write_timestamp( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError>; - fn begin_pipeline_statistics_query( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError>; - fn end_pipeline_statistics_query(&mut self, context: &wgc::global::Global); -} - -impl DynComputePass for wgc::command::ComputePass { - fn run(&mut self, 
context: &wgc::global::Global) -> Result<(), wgc::command::ComputePassError> { - context.command_encoder_run_compute_pass(self) - } - - fn set_bind_group( - &mut self, - context: &wgc::global::Global, - index: u32, - bind_group_id: wgc::id::BindGroupId, - offsets: &[wgt::DynamicOffset], - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_set_bind_group(self, index, bind_group_id, offsets) - } - - fn set_pipeline( - &mut self, - context: &wgc::global::Global, - pipeline_id: wgc::id::ComputePipelineId, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_set_pipeline(self, pipeline_id) - } - - fn set_push_constant(&mut self, context: &wgc::global::Global, offset: u32, data: &[u8]) { - context.compute_pass_set_push_constant(self, offset, data) - } - - fn dispatch_workgroups( - &mut self, - context: &wgc::global::Global, - groups_x: u32, - groups_y: u32, - groups_z: u32, - ) { - context.compute_pass_dispatch_workgroups(self, groups_x, groups_y, groups_z) - } - - fn dispatch_workgroups_indirect( - &mut self, - context: &wgc::global::Global, - buffer_id: wgc::id::BufferId, - offset: wgt::BufferAddress, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_dispatch_workgroups_indirect(self, buffer_id, offset) - } - - fn push_debug_group(&mut self, context: &wgc::global::Global, label: &str, color: u32) { - context.compute_pass_push_debug_group(self, label, color) - } - - fn pop_debug_group(&mut self, context: &wgc::global::Global) { - context.compute_pass_pop_debug_group(self) - } - - fn insert_debug_marker(&mut self, context: &wgc::global::Global, label: &str, color: u32) { - context.compute_pass_insert_debug_marker(self, label, color) - } - - fn write_timestamp( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_write_timestamp(self, query_set_id, query_index) - } - - fn 
begin_pipeline_statistics_query( - &mut self, - context: &wgc::global::Global, - query_set_id: wgc::id::QuerySetId, - query_index: u32, - ) -> Result<(), wgc::command::ComputePassError> { - context.compute_pass_begin_pipeline_statistics_query(self, query_set_id, query_index) - } - - fn end_pipeline_statistics_query(&mut self, context: &wgc::global::Global) { - context.compute_pass_end_pipeline_statistics_query(self) - } -} - #[derive(Debug)] pub struct CommandEncoder { error_sink: ErrorSink, @@ -1967,22 +1835,10 @@ impl crate::Context for ContextWgpuCore { end_of_pass_write_index: tw.end_of_pass_write_index, }); - fn create_dyn_compute_pass( - context: &ContextWgpuCore, - encoder: wgc::id::CommandEncoderId, - desc: &wgc::command::ComputePassDescriptor<'_>, - ) -> Box { - Box::new( - context - .0 - .command_encoder_create_compute_pass::(encoder, desc), - ) - } - ( Unused, Self::ComputePassData { - pass: gfx_select!(encoder => create_dyn_compute_pass(self, *encoder, &wgc::command::ComputePassDescriptor { + pass: gfx_select!(encoder => self.0.command_encoder_create_compute_pass_dyn(*encoder, &wgc::command::ComputePassDescriptor { label: desc.label.map(Borrowed), timestamp_writes: timestamp_writes.as_ref(), })), From 22c1e68c91ccfdaab7b320331bba60ea594f3174 Mon Sep 17 00:00:00 2001 From: Andreas Reich Date: Sun, 21 Apr 2024 10:21:44 +0200 Subject: [PATCH 11/11] split out resources setup from test --- .../tests/compute_pass_resource_ownership.rs | 110 +++++++++++------- 1 file changed, 71 insertions(+), 39 deletions(-) diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs index 8a07b76201..6612ad0068 100644 --- a/tests/tests/compute_pass_resource_ownership.rs +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -1,9 +1,12 @@ -//! Tests that compute passes take ownership of resources that are passed in. +//! Tests that compute passes take ownership of resources that are associated with it. //! I.e. 
once a resource is passed in to a compute pass, it can be dropped. //! //! TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet. //! (Not important as long as they are lifetime constrained to the command encoder, //! but once we lift this constraint, we should add tests for this as well!) +//! TODO: Also should test resource ownership for: +//! * write_timestamp +//! * begin_pipeline_statistics_query use std::num::NonZeroU64; @@ -25,6 +28,66 @@ static COMPUTE_PASS_RESOURCE_OWNERSHIP: GpuTestConfiguration = GpuTestConfigurat .run_async(compute_pass_resource_ownership); async fn compute_pass_resource_ownership(ctx: TestingContext) { + let ResourceSetup { + gpu_buffer, + cpu_buffer, + buffer_size, + indirect_buffer, + bind_group, + pipeline, + } = resource_setup(&ctx); + + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("encoder"), + }); + + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("compute_pass"), + timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound. + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); + + // Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what. + drop(pipeline); + drop(bind_group); + drop(indirect_buffer); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + } + + // Ensure that the compute pass still executed normally. 
+ encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size); + ctx.queue.submit([encoder.finish()]); + cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ()); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let data = cpu_buffer.slice(..).get_mapped_range(); + + let floats: &[f32] = bytemuck::cast_slice(&data); + assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]); +} + +// Setup ------------------------------------------------------------ + +struct ResourceSetup { + gpu_buffer: wgpu::Buffer, + cpu_buffer: wgpu::Buffer, + buffer_size: u64, + + indirect_buffer: wgpu::Buffer, + bind_group: wgpu::BindGroup, + pipeline: wgpu::ComputePipeline, +} + +fn resource_setup(ctx: &TestingContext) -> ResourceSetup { let sm = ctx .device .create_shader_module(wgpu::ShaderModuleDescriptor { @@ -100,43 +163,12 @@ async fn compute_pass_resource_ownership(ctx: TestingContext) { compilation_options: Default::default(), }); - let mut encoder = ctx - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("encoder"), - }); - - { - let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { - label: Some("compute_pass"), - timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound. - }); - cpass.set_pipeline(&pipeline); - cpass.set_bind_group(0, &bind_group, &[]); - cpass.dispatch_workgroups_indirect(&indirect_buffer, 0); - - // TODO: - // write_timestamp - // begin_pipeline_statistics_query - - // Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what. 
- drop(pipeline); - drop(bind_group); - drop(indirect_buffer); - ctx.async_poll(wgpu::Maintain::wait()) - .await - .panic_on_timeout(); + ResourceSetup { + gpu_buffer, + cpu_buffer, + buffer_size, + indirect_buffer, + bind_group, + pipeline, } - - encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size); - ctx.queue.submit([encoder.finish()]); - cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ()); - ctx.async_poll(wgpu::Maintain::wait()) - .await - .panic_on_timeout(); - - let data = cpu_buffer.slice(..).get_mapped_range(); - - let floats: &[f32] = bytemuck::cast_slice(&data); - assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]); }