Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove lifetime constraints from wgpu::ComputePass methods #5570

Merged
merged 12 commits into from
May 14, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ Bottom level categories:
- Added support for pipeline-overridable constants. By @teoxoy & @jimblandy in [#5500](https://github.com/gfx-rs/wgpu/pull/5500)
- Add `SUBGROUP, SUBGROUP_VERTEX, SUBGROUP_BARRIER` features. By @exrook and @lichtso in [#5301](https://github.com/gfx-rs/wgpu/pull/5301)
- Support disabling zero-initialization of workgroup local memory in compute shaders. By @DJMcNab in [#5508](https://github.com/gfx-rs/wgpu/pull/5508)
- `wgpu::ComputePass` recording methods (e.g. `wgpu::ComputePass::set_pipeline`) no longer impose a lifetime constraint on passed-in resources. By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569)

#### GLES

Expand Down
5 changes: 3 additions & 2 deletions deno_webgpu/command_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,14 @@ pub fn op_webgpu_command_encoder_begin_compute_pass(
None
};

let instance = state.borrow::<super::Instance>();
let command_encoder = &command_encoder_resource.1;
let descriptor = wgpu_core::command::ComputePassDescriptor {
label: Some(label),
timestamp_writes: timestamp_writes.as_ref(),
};

let compute_pass =
wgpu_core::command::ComputePass::new(command_encoder_resource.1, &descriptor);
let compute_pass = gfx_select!(command_encoder => instance.command_encoder_create_compute_pass_dyn(*command_encoder, &descriptor));

let rid = state
.resource_table
Expand Down
67 changes: 29 additions & 38 deletions deno_webgpu/compute_pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use std::cell::RefCell;

use super::error::WebGpuResult;

pub(crate) struct WebGpuComputePass(pub(crate) RefCell<wgpu_core::command::ComputePass>);
pub(crate) struct WebGpuComputePass(
pub(crate) RefCell<Box<dyn wgpu_core::command::DynComputePass>>,
);
impl Resource for WebGpuComputePass {
fn name(&self) -> Cow<str> {
"webGPUComputePass".into()
Expand All @@ -31,10 +33,10 @@ pub fn op_webgpu_compute_pass_set_pipeline(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_set_pipeline(
&mut compute_pass_resource.0.borrow_mut(),
compute_pipeline_resource.1,
);
compute_pass_resource
.0
.borrow_mut()
.set_pipeline(state.borrow(), compute_pipeline_resource.1)?;

Ok(WebGpuResult::empty())
}
Expand All @@ -52,12 +54,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups(
&mut compute_pass_resource.0.borrow_mut(),
x,
y,
z,
);
compute_pass_resource
.0
.borrow_mut()
.dispatch_workgroups(state.borrow(), x, y, z);

Ok(WebGpuResult::empty())
}
Expand All @@ -77,11 +77,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups_indirect(
&mut compute_pass_resource.0.borrow_mut(),
buffer_resource.1,
indirect_offset,
);
compute_pass_resource
.0
.borrow_mut()
.dispatch_workgroups_indirect(state.borrow(), buffer_resource.1, indirect_offset)?;

Ok(WebGpuResult::empty())
}
Expand All @@ -90,24 +89,15 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect(
#[serde]
pub fn op_webgpu_compute_pass_end(
state: &mut OpState,
#[smi] command_encoder_rid: ResourceId,
#[smi] compute_pass_rid: ResourceId,
) -> Result<WebGpuResult, AnyError> {
let command_encoder_resource =
state
.resource_table
.get::<super::command_encoder::WebGpuCommandEncoder>(command_encoder_rid)?;
let command_encoder = command_encoder_resource.1;
let compute_pass_resource = state
.resource_table
.take::<WebGpuComputePass>(compute_pass_rid)?;
let compute_pass = &compute_pass_resource.0.borrow();
let instance = state.borrow::<super::Instance>();

gfx_ok!(command_encoder => instance.command_encoder_run_compute_pass(
command_encoder,
compute_pass
))
compute_pass_resource.0.borrow_mut().run(state.borrow())?;

Ok(WebGpuResult::empty())
}

#[op2]
Expand Down Expand Up @@ -137,12 +127,12 @@ pub fn op_webgpu_compute_pass_set_bind_group(

let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len];

wgpu_core::command::compute_commands::wgpu_compute_pass_set_bind_group(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().set_bind_group(
state.borrow(),
index,
bind_group_resource.1,
dynamic_offsets_data,
);
)?;

Ok(WebGpuResult::empty())
}
Expand All @@ -158,8 +148,8 @@ pub fn op_webgpu_compute_pass_push_debug_group(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_push_debug_group(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().push_debug_group(
state.borrow(),
group_label,
0, // wgpu#975
);
Expand All @@ -177,9 +167,10 @@ pub fn op_webgpu_compute_pass_pop_debug_group(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_pop_debug_group(
&mut compute_pass_resource.0.borrow_mut(),
);
compute_pass_resource
.0
.borrow_mut()
.pop_debug_group(state.borrow());

Ok(WebGpuResult::empty())
}
Expand All @@ -195,8 +186,8 @@ pub fn op_webgpu_compute_pass_insert_debug_marker(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_insert_debug_marker(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().insert_debug_marker(
state.borrow(),
marker_label,
0, // wgpu#975
);
Expand Down
174 changes: 174 additions & 0 deletions tests/tests/compute_pass_resource_ownership.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
//! Tests that compute passes take ownership of the resources they are associated with.
//! I.e. once a resource is passed in to a compute pass, it can be dropped.
//!
//! TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet.
//! (Not important as long as they are lifetime constrained to the command encoder,
//! but once we lift this constraint, we should add tests for this as well!)
//! TODO: Also should test resource ownership for:
//! * write_timestamp
//! * begin_pipeline_statistics_query

use std::num::NonZeroU64;

use wgpu::util::DeviceExt as _;
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

/// WGSL compute shader used by this test.
///
/// Its `main` entry point uses a workgroup size of 1x1x1 and doubles element 0 of the
/// read-write storage array (`array<vec4f>`) bound at group 0, binding 0.
const SHADER_SRC: &str = "
@group(0) @binding(0)
var<storage, read_write> buffer: array<vec4f>;

@compute @workgroup_size(1, 1, 1) fn main() {
buffer[0] *= 2.0;
}
";

// Test registration: runs `compute_pass_resource_ownership` asynchronously, using the default
// test parameters with `test_features_limits()` applied.
#[gpu_test]
static COMPUTE_PASS_RESOURCE_OWNERSHIP: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits())
.run_async(compute_pass_resource_ownership);

/// Records a compute pass, drops the pipeline, bind group and indirect buffer while the pass is
/// still being recorded, and then checks that the dispatch nevertheless ran by reading the
/// storage buffer back through a mappable copy.
async fn compute_pass_resource_ownership(ctx: TestingContext) {
    let ResourceSetup {
        gpu_buffer,
        cpu_buffer,
        buffer_size,
        indirect_buffer,
        bind_group,
        pipeline,
    } = resource_setup(&ctx);

    let encoder_desc = wgpu::CommandEncoderDescriptor {
        label: Some("encoder"),
    };
    let mut command_encoder = ctx.device.create_command_encoder(&encoder_desc);

    {
        let pass_desc = wgpu::ComputePassDescriptor {
            label: Some("compute_pass"),
            // TODO: As noted in the module docs, timestamp writes should be covered here too
            // once their lifetime bound is lifted.
            timestamp_writes: None,
        };
        let mut pass = command_encoder.begin_compute_pass(&pass_desc);
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups_indirect(&indirect_buffer, 0);

        // Release our handles to everything that was set on the pass, then poll the device so
        // that any premature destruction of those resources would surface here.
        drop(pipeline);
        drop(bind_group);
        drop(indirect_buffer);
        ctx.async_poll(wgpu::Maintain::wait())
            .await
            .panic_on_timeout();
    }

    // The pass must still have executed normally: copy the result out and submit.
    command_encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
    ctx.queue.submit([command_encoder.finish()]);

    let readback = cpu_buffer.slice(..);
    readback.map_async(wgpu::MapMode::Read, |_| ());
    ctx.async_poll(wgpu::Maintain::wait())
        .await
        .panic_on_timeout();

    let mapped = readback.get_mapped_range();
    let floats: &[f32] = bytemuck::cast_slice(&mapped);
    assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
}

// Setup ------------------------------------------------------------

/// All resources created by `resource_setup` for the compute pass ownership test.
struct ResourceSetup {
/// Storage buffer (`STORAGE | COPY_SRC`) initialised with `[1.0, 2.0, 3.0, 4.0]`.
gpu_buffer: wgpu::Buffer,
/// Mappable buffer (`COPY_DST | MAP_READ`) of `buffer_size` bytes.
cpu_buffer: wgpu::Buffer,
/// Size in bytes of four `f32` values.
buffer_size: u64,

/// Buffer with `INDIRECT` usage holding dispatch arguments `x: 1, y: 1, z: 1`.
indirect_buffer: wgpu::Buffer,
/// Bind group exposing the whole of `gpu_buffer` at binding 0.
bind_group: wgpu::BindGroup,
/// Compute pipeline built from `SHADER_SRC` with entry point `main`.
pipeline: wgpu::ComputePipeline,
}

fn resource_setup(ctx: &TestingContext) -> ResourceSetup {
let sm = ctx
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let buffer_size = 4 * std::mem::size_of::<f32>() as u64;

let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("bind_group_layout"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(buffer_size),
},
count: None,
}],
});

let gpu_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu_buffer"),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
contents: bytemuck::bytes_of(&[1.0_f32, 2.0, 3.0, 4.0]),
});

let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("cpu_buffer"),
size: buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let indirect_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu_buffer"),
usage: wgpu::BufferUsages::INDIRECT,
contents: wgpu::util::DispatchIndirectArgs { x: 1, y: 1, z: 1 }.as_bytes(),
});

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("bind_group"),
layout: &bgl,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: gpu_buffer.as_entire_binding(),
}],
});

let pipeline_layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("pipeline_layout"),
bind_group_layouts: &[&bgl],
push_constant_ranges: &[],
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("pipeline"),
layout: Some(&pipeline_layout),
module: &sm,
entry_point: "main",
compilation_options: Default::default(),
});

ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer,
bind_group,
pipeline,
}
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod buffer;
mod buffer_copy;
mod buffer_usages;
mod clear_texture;
mod compute_pass_resource_ownership;
mod create_surface_error;
mod device;
mod encoder;
Expand Down
5 changes: 2 additions & 3 deletions wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
device::SHADER_STAGE_COUNT,
hal_api::HalApi,
id::BindGroupId,
pipeline::LateSizedBufferGroup,
resource::Resource,
};
Expand Down Expand Up @@ -359,11 +358,11 @@ impl<A: HalApi> Binder<A> {
&self.payloads[bind_range]
}

pub(super) fn list_active(&self) -> impl Iterator<Item = BindGroupId> + '_ {
pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
let payloads = &self.payloads;
self.manager
.list_active()
.map(move |index| payloads[index].group.as_ref().unwrap().as_info().id())
.map(move |index| payloads[index].group.as_ref().unwrap())
}

pub(super) fn invalid_mask(&self) -> BindGroupMask {
Expand Down