Skip to content

Commit

Permalink
feat(vulkan): add arithmetic shader (#193)
Browse files Browse the repository at this point in the history
* add push constants

* add add_inplace

* add arithmetic

* add test cases for arithmetics

* fix export
  • Loading branch information
flaneur2020 committed May 19, 2024
1 parent 7fc4af6 commit 76febc1
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 49 deletions.
1 change: 1 addition & 0 deletions crabml-vulkan/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(dead_code)]

mod push_constants;
pub mod vulkan_device;
pub mod vulkan_tensor;
10 changes: 10 additions & 0 deletions crabml-vulkan/src/push_constants.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use vulkano::buffer::BufferContents;

/// Push-constant block for the `arithmetic.comp` compute shader.
///
/// Field order and types must match the `PushConstants` uniform block
/// declared in the shader (std430-compatible via `#[repr(C)]`).
#[derive(BufferContents)]
#[repr(C)]
pub struct ArithmeticPushConstants {
/// Number of elements in the destination buffer to process.
pub n_elms: u32,
/// Operator selector; encoded as the ASCII code of the operator
/// character (`'+' as u32`, `'-'`, `'*'`, `'/'`).
pub op: u32,
/// Non-zero: use `scalar_rhs` as the right-hand side instead of the
/// second buffer binding.
pub use_scalar_rhs: u32,
/// Scalar right-hand side, read only when `use_scalar_rhs` is non-zero.
pub scalar_rhs: f32,
}
16 changes: 0 additions & 16 deletions crabml-vulkan/src/shaders/add.comp

This file was deleted.

54 changes: 54 additions & 0 deletions crabml-vulkan/src/shaders/arithmetic.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#version 450

// Element-wise arithmetic kernel: bufA[i] = bufA[i] OP rhs, where rhs is
// either the corresponding element of bufB (broadcast by modulo when bufB
// is shorter) or a scalar passed via push constants.

layout(local_size_x = 32) in;

// Destination operand; updated in place.
layout(set = 0, binding = 0) buffer InputBufferA {
    float bufA[];
};

// Right-hand-side operand; read-only unless use_scalar_rhs is set.
layout(set = 0, binding = 1) buffer InputBufferB {
    float bufB[];
};

// Must match ArithmeticPushConstants on the host side (field order/types).
layout(push_constant) uniform PushConstants {
    uint nElems;         // number of elements in bufA to process
    uint op;             // ASCII code of the operator character
    uint use_scalar_rhs; // non-zero: rhs is scalar_rhs, bufB is ignored
    float scalar_rhs;
} pcs;

// Operator codes are ASCII values, matching `'+' as u32` etc. on the host.
// Declared as uint so the case labels match the type of the switch
// expression (pcs.op is uint).
const uint OP_ADD = 43; // '+'
const uint OP_SUB = 45; // '-'
const uint OP_MUL = 42; // '*'
const uint OP_DIV = 47; // '/'

void main() {
    uint idxA = gl_GlobalInvocationID.x;

    // Guard the partially-filled trailing workgroup.
    if (idxA >= pcs.nElems) {
        return;
    }

    float rhs = 0.0;
    if (pcs.use_scalar_rhs > 0) {
        rhs = pcs.scalar_rhs;
    } else {
        // Broadcast bufB across bufA by repeating it when it is shorter.
        uint idxB = idxA % bufB.length();
        rhs = bufB[idxB];
    }

    switch (pcs.op) {
        case OP_ADD:
            bufA[idxA] += rhs;
            break;
        case OP_SUB:
            bufA[idxA] -= rhs;
            break;
        case OP_MUL:
            // BUGFIX: previously wrote to bufB, mutating the rhs tensor
            // instead of the destination.
            bufA[idxA] *= rhs;
            break;
        case OP_DIV:
            // BUGFIX: previously wrote to bufB (see OP_MUL).
            bufA[idxA] /= rhs;
            break;
    }
}
20 changes: 14 additions & 6 deletions crabml-vulkan/src/vulkan_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;

use bytemuck::NoUninit;
use vulkano::buffer::Buffer;
use vulkano::buffer::BufferContents;
use vulkano::buffer::BufferCreateInfo;
use vulkano::buffer::BufferUsage;
use vulkano::buffer::Subbuffer;
Expand Down Expand Up @@ -94,12 +95,16 @@ impl VulkanTensorDevice {
}

fn load_shaders(&mut self) {
mod add {
vulkano_shaders::shader! { ty: "compute", path: "./src/shaders/add.comp" }
mod arithmetic_shader {
vulkano_shaders::shader! { ty: "compute", path: "./src/shaders/arithmetic.comp" }
}

let device = self.inner.device.clone();
let entry_points = [("add", load_shader_entry_point!(add, device.clone(), "main"))];
let entry_points = [(
"arithmetic",
load_shader_entry_point!(arithmetic_shader, device.clone(), "main"),
)];

for (name, entry_point) in entry_points.into_iter() {
self.inner.load_compute_pipeline(name, entry_point);
}
Expand Down Expand Up @@ -277,10 +282,11 @@ impl VulkanTensorDeviceInner {
.for_each(|(s, d)| *d = *s);
}

pub fn dispatch_compute(
pub fn dispatch_compute<Pc: BufferContents>(
&self,
pipeline_name: &str,
buffers: Vec<Subbuffer<[u8]>>,
push_constants: Pc,
dispatch_group: [u32; 3],
) {
let pipeline = self.pipelines.get(pipeline_name).unwrap();
Expand Down Expand Up @@ -318,9 +324,11 @@ impl VulkanTensorDeviceInner {
0,
set,
)
.unwrap()
.dispatch(dispatch_group)
.unwrap();
builder
.push_constants(pipeline.layout().clone(), 0, push_constants)
.unwrap();
builder.dispatch(dispatch_group).unwrap();
builder.build().unwrap()
};

Expand Down
140 changes: 113 additions & 27 deletions crabml-vulkan/src/vulkan_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crabml::tensor::Tensor;
use crabml::tensor::TensorStrider;

use super::vulkan_device::VulkanTensorDeviceRef;
use crate::push_constants::ArithmeticPushConstants;

#[derive(Clone)]
pub struct VulkanTensor {
Expand Down Expand Up @@ -38,26 +39,6 @@ impl VulkanTensor {
name: None,
})
}

pub fn export(&self, dst: &mut [f32]) -> Result<()> {
let buf_size = std::mem::size_of_val(dst);
if buf_size > self.device.opts.staging_buf_bytes {
return Err((
ErrorKind::TensorError,
format!(
"buffer size exceeded staging buffer limit: {}, got: {}",
self.device.opts.staging_buf_bytes, buf_size,
),
)
.into());
}

let dst_bytes = bytemuck::cast_slice_mut(dst);
self.device
.inner
.copy_device_buffer_to_cpu(self.buf.clone(), dst_bytes);
Ok(())
}
}

impl Tensor for VulkanTensor {
Expand Down Expand Up @@ -120,8 +101,24 @@ impl Tensor for VulkanTensor {
todo!()
}

fn export(&self, buf: &mut [f32]) -> Result<()> {
todo!()
fn export(&self, dst: &mut [f32]) -> Result<()> {
let buf_size = std::mem::size_of_val(dst);
if buf_size > self.device.opts.staging_buf_bytes {
return Err((
ErrorKind::TensorError,
format!(
"buffer size exceeded staging buffer limit: {}, got: {}",
self.device.opts.staging_buf_bytes, buf_size,
),
)
.into());
}

let dst_bytes = bytemuck::cast_slice_mut(dst);
self.device
.inner
.copy_device_buffer_to_cpu(self.buf.clone(), dst_bytes);
Ok(())
}

fn dup(&self) -> Result<Self> {
Expand Down Expand Up @@ -154,27 +151,77 @@ impl Tensor for VulkanTensor {
}

fn mul_inplace(self, rhs: &Self) -> Result<Self> {
    // Element-wise multiply on the GPU: self[i] *= rhs[i % rhs.len()],
    // using the shared "arithmetic" compute shader with op '*'.
    assert!(self.strider.is_contiguous());
    assert!(rhs.strider.is_contiguous());

    let n_elms = self.strider.len() as u32;
    let bufs = vec![self.buf.clone(), rhs.buf.clone()];
    let pcs = ArithmeticPushConstants {
        n_elms,
        op: '*' as u32, // operator is encoded as its ASCII code
        use_scalar_rhs: 0,
        scalar_rhs: 0.0,
    };
    // Ceiling division by the workgroup size (32): dispatch exactly enough
    // groups; out-of-range invocations exit early in the shader.
    let dispatches = [(n_elms + 31) / 32, 1, 1];
    self.device
        .inner
        .dispatch_compute("arithmetic", bufs, pcs, dispatches);
    Ok(self)
}

fn add_inplace(self, rhs: &Self) -> Result<Self> {
    // Element-wise add on the GPU: self[i] += rhs[i % rhs.len()],
    // using the shared "arithmetic" compute shader with op '+'.
    assert!(self.strider.is_contiguous());
    assert!(rhs.strider.is_contiguous());

    // TODO: pass n_elm as meta
    let n_elms = self.strider.len() as u32;
    let bufs = vec![self.buf.clone(), rhs.buf.clone()];
    let pcs = ArithmeticPushConstants {
        n_elms,
        op: '+' as u32, // operator is encoded as its ASCII code
        use_scalar_rhs: 0,
        scalar_rhs: 0.0,
    };
    // Ceiling division by the workgroup size (32): dispatch exactly enough
    // groups; out-of-range invocations exit early in the shader.
    let dispatches = [(n_elms + 31) / 32, 1, 1];
    self.device
        .inner
        .dispatch_compute("arithmetic", bufs, pcs, dispatches);
    Ok(self)
}

fn div_scalar_inplace(self, rhs: f32) -> Result<Self> {
    // Divide every element by the scalar `rhs` on the GPU, using the
    // shared "arithmetic" shader with the scalar-rhs path enabled.
    assert!(self.strider.is_contiguous());

    let n_elms = self.strider.len() as u32;
    // The shader requires a second buffer binding even in scalar mode;
    // bind self's buffer twice (it is never read in scalar mode).
    let bufs = vec![self.buf.clone(), self.buf.clone()];
    let pcs = ArithmeticPushConstants {
        n_elms,
        op: '/' as u32, // operator is encoded as its ASCII code
        use_scalar_rhs: 1,
        scalar_rhs: rhs,
    };
    // Ceiling division by the workgroup size (32): dispatch exactly enough
    // groups; out-of-range invocations exit early in the shader.
    let dispatches = [(n_elms + 31) / 32, 1, 1];
    self.device
        .inner
        .dispatch_compute("arithmetic", bufs, pcs, dispatches);
    Ok(self)
}

fn scale_inplace(self, rhs: f32) -> Result<Self> {
    // Multiply every element by the scalar `rhs` on the GPU, using the
    // shared "arithmetic" shader with the scalar-rhs path enabled.
    assert!(self.strider.is_contiguous());

    let n_elms = self.strider.len() as u32;
    // The shader requires a second buffer binding even in scalar mode;
    // bind self's buffer twice (it is never read in scalar mode).
    let bufs = vec![self.buf.clone(), self.buf.clone()];
    let pcs = ArithmeticPushConstants {
        n_elms,
        op: '*' as u32, // operator is encoded as its ASCII code
        use_scalar_rhs: 1,
        scalar_rhs: rhs,
    };
    // Ceiling division by the workgroup size (32): dispatch exactly enough
    // groups; out-of-range invocations exit early in the shader.
    let dispatches = [(n_elms + 31) / 32, 1, 1];
    self.device
        .inner
        .dispatch_compute("arithmetic", bufs, pcs, dispatches);
    Ok(self)
}

fn matmul_vec(&self, y: &Self) -> Result<Self> {
Expand Down Expand Up @@ -217,4 +264,43 @@ mod tests {
]);
Ok(())
}

#[test]
fn test_scale_inplace() -> Result<()> {
    // Scaling by 2.0 on the GPU must match the CPU-computed reference.
    // 34 elements deliberately exceeds one 32-wide workgroup, exercising
    // the partial trailing group.
    let device = VulkanTensorDevice::new(VulkanTensorDeviceOptions::default());

    let input: Vec<f32> = (0..34).map(|v| v as f32).collect();
    let tensor = VulkanTensor::new(&input, &[34], device.clone()).unwrap();
    let tensor = tensor.scale_inplace(2.0).unwrap();

    let mut output = vec![0.0; 34];
    tensor.export(&mut output)?;

    let expected: Vec<f32> = (0..34).map(|v| (v * 2) as f32).collect();
    assert_eq!(output, expected);
    Ok(())
}

#[test]
fn test_div_scalar() -> Result<()> {
    // Dividing by 2.0 on the GPU must match the CPU-computed reference
    // (exact halves of small integers, so float comparison is exact).
    let device = VulkanTensorDevice::new(VulkanTensorDeviceOptions::default());

    let input: Vec<f32> = (0..32).map(|v| v as f32).collect();
    let tensor = VulkanTensor::new(&input, &[32], device.clone()).unwrap();
    let tensor = tensor.div_scalar_inplace(2.0).unwrap();

    let mut output = vec![0.0; 32];
    tensor.export(&mut output)?;

    let expected: Vec<f32> = (0..32).map(|v| v as f32 / 2.0).collect();
    assert_eq!(output, expected);
    Ok(())
}
}

0 comments on commit 76febc1

Please sign in to comment.