/
sort_indirect.rs
103 lines (81 loc) · 3.13 KB
/
sort_indirect.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
// This example creates an array of 10 key-value (f32, u32) pairs and sorts them on the GPU.
// Important: sorting by f32 keys only works for non-negative key values. NaN and inf values also give unexpected results.
use std::num::NonZeroU32;
use bytemuck::bytes_of;
use wgpu::util::DeviceExt;
use wgpu_sort::{utils::{download_buffer, guess_workgroup_size, upload_to_buffer}, GPUSorter, HISTO_BLOCK_KVS};
// Entry point of the example: uploads 10 (f32 key, u32 value) pairs, sorts them
// on the GPU via `GPUSorter::sort_indirect`, and prints the pairs before and after.
#[pollster::main]
async fn main() {
    // Acquire an adapter (selected from the environment or the default one)
    // and open a device/queue pair with no extra features and default limits.
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default());
    let adapter = wgpu::util::initialize_adapter_from_env_or_default(&instance, None)
        .await
        .unwrap();

    let device_desc = wgpu::DeviceDescriptor {
        label: None,
        required_features: wgpu::Features::empty(),
        required_limits: wgpu::Limits::default(),
    };
    let (device, queue) = adapter.request_device(&device_desc, None).await.unwrap();

    // The sorter is constructed for a specific subgroup size; abort if the
    // helper cannot determine a usable one.
    let subgroup_size = guess_workgroup_size(&device, &queue)
        .await
        .expect("could not find a valid subgroup size");
    println!("using subgroup size {subgroup_size}");
    let sorter = GPUSorter::new(&device, subgroup_size);

    // Number of key-value pairs to sort.
    let pair_count: u32 = 10;
    let sort_buffers =
        sorter.create_sort_buffers(&device, NonZeroU32::new(pair_count).unwrap());

    // Keys are 1/1, 1/2, ..., 1/10 (i.e. descending); values are 1, 2, ..., 10.
    let input_keys: Vec<f32> = (1..=pair_count).map(|i| 1.0 / i as f32).collect();
    let input_values: Vec<u32> = (1..=pair_count).collect();

    let encoder_desc = wgpu::CommandEncoderDescriptor {
        label: Some("GPURSSorter test_sort"),
    };
    let mut encoder = device.create_command_encoder(&encoder_desc);

    // Stage the input data into the sorter's key and value buffers.
    upload_to_buffer(
        &mut encoder,
        &sort_buffers.keys(),
        &device,
        input_keys.as_slice(),
    );
    upload_to_buffer(
        &mut encoder,
        &sort_buffers.values(),
        &device,
        input_values.as_slice(),
    );

    let pairs_before: Vec<(_, _)> = input_keys.iter().zip(input_values.iter()).collect();
    println!("before: {:?}", pairs_before);

    // Ceiling division: number of workgroups needed so that every block of
    // HISTO_BLOCK_KVS key-value pairs is covered.
    let workgroup_count = (pair_count + HISTO_BLOCK_KVS - 1) / HISTO_BLOCK_KVS;
    let indirect_args = wgpu::util::DispatchIndirectArgs {
        x: workgroup_count,
        y: 1,
        z: 1,
    };

    // Write the element count to offset 0 of the sorter's state buffer.
    // NOTE(review): presumably the indirect sort reads the count from here
    // instead of receiving it as a CPU-side argument -- confirm against wgpu_sort.
    queue.write_buffer(sort_buffers.state_buffer(), 0, bytes_of(&pair_count));

    let indirect_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("dispatch indirect buffer"),
        contents: indirect_args.as_bytes(),
        usage: wgpu::BufferUsages::INDIRECT,
    });

    // Record the sort, with dispatch dimensions taken from the indirect buffer.
    sorter.sort_indirect(&mut encoder, &sort_buffers, &indirect_buffer);

    // Submit the recorded commands and wait for the sorter to finish.
    let submission = queue.submit([encoder.finish()]);
    device.poll(wgpu::Maintain::WaitForSubmissionIndex(submission));

    // The keys buffer has padding at the end, so only the "valid" range is
    // downloaded; the values buffer is downloaded in full.
    let sorted_keys = download_buffer::<f32>(
        &sort_buffers.keys(),
        &device,
        &queue,
        0..sort_buffers.keys_valid_size(),
    )
    .await;
    let sorted_values = download_buffer::<u32>(&sort_buffers.values(), &device, &queue, ..).await;

    let pairs_after: Vec<(_, _)> = sorted_keys.iter().zip(sorted_values.iter()).collect();
    println!("after: {:?}", pairs_after);
}