-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(vulkan): add arithmetic shader (#193)
* add push constants * add add_inplace * add arithmetic * add test cases for arithmetics * fix export
- Loading branch information
1 parent
7fc4af6
commit 76febc1
Showing
6 changed files
with
192 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
#![allow(dead_code)] | ||
|
||
mod push_constants; | ||
pub mod vulkan_device; | ||
pub mod vulkan_tensor; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
use vulkano::buffer::BufferContents; | ||
|
||
#[derive(BufferContents)] | ||
#[repr(C)] | ||
pub struct ArithmeticPushConstants { | ||
pub n_elms: u32, | ||
pub op: u32, | ||
pub use_scalar_rhs: u32, | ||
pub scalar_rhs: f32, | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#version 450 | ||
|
||
layout(local_size_x = 32) in; | ||
|
||
layout(set = 0, binding = 0) buffer InputBufferA { | ||
float bufA[]; | ||
}; | ||
|
||
layout(set = 0, binding = 1) buffer InputBufferB { | ||
float bufB[]; | ||
}; | ||
|
||
layout(push_constant) uniform PushConstants { | ||
uint nElems; | ||
uint op; | ||
uint use_scalar_rhs; | ||
float scalar_rhs; | ||
} pcs; | ||
|
||
const int OP_ADD = 43; | ||
const int OP_SUB = 45; | ||
const int OP_MUL = 42; | ||
const int OP_DIV = 47; | ||
|
||
void main() { | ||
uint idxA = gl_GlobalInvocationID.x; | ||
|
||
if (idxA >= pcs.nElems) { | ||
return; | ||
} | ||
|
||
float rhs = 0.0; | ||
if (pcs.use_scalar_rhs > 0) { | ||
rhs = pcs.scalar_rhs; | ||
} else { | ||
uint idxB = idxA % bufB.length(); | ||
rhs = bufB[idxB]; | ||
} | ||
|
||
switch (pcs.op) { | ||
case OP_ADD: | ||
bufA[idxA] += rhs; | ||
break; | ||
case OP_SUB: | ||
bufA[idxA] -= rhs; | ||
break; | ||
case OP_MUL: | ||
bufB[idxA] *= rhs; | ||
break; | ||
case OP_DIV: | ||
bufB[idxA] /= rhs; | ||
break; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters