Skip to content

Commit

Permalink
feat(wgpu): add to_dtype kernel (#906)
Browse files Browse the repository at this point in the history
* feat(wgpu): add to_dtype kernel

* fix: add WebGPUNativeType

* style: clippy fix

---------

Co-authored-by: Corey Lowman <clowman1993@gmail.com>
  • Loading branch information
DonIsaac and coreylowman committed Jan 25, 2024
1 parent e04dd4f commit 4722a99
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 20 deletions.
30 changes: 30 additions & 0 deletions dfdx-core/src/tensor/webgpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,36 @@ impl Webgpu {
pub(crate) fn get_shader_module(&self, name: TypeId) -> Option<Arc<ShaderModule>> {
self.cs_cache.read().get(&name).cloned()
}
/// Submit a command buffer to the GPU.
///
/// Note: Does not block until completion. If you need this, use
/// `self.dev.poll(Maintain::WaitForSubmissionIndex(idx))` using the
/// returned [`wgpu::SubmissionIndex`]
pub(crate) fn submit_commands<F>(
&self,
label: Option<&str>,
command_builder: F,
) -> wgpu::SubmissionIndex
where
F: FnOnce(&mut wgpu::CommandEncoder),
{
let mut encoder = self
.dev
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: label.clone(),
});

if let Some(label) = label {
encoder.push_debug_group(label);
}
command_builder(&mut encoder);
if labe.is_some() {
encoder.pop_debug_group();
}

let cmd = [encoder.finish()];
self.queue.submit(cmd)
}

// #[allow(unused)]
// pub(crate) unsafe fn get_workspace<E>(&self, len: usize) -> Result<MutexGuard<Buffer>, Error> {
Expand Down
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor/webgpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod allocate;
mod device;
mod types;

pub use device::Buffer;
pub use device::Webgpu;
pub use types::*;

#[cfg(test)]
mod tests {
Expand Down
56 changes: 56 additions & 0 deletions dfdx-core/src/tensor/webgpu/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use crate::shapes::Unit;

/// A primitive data type natively supported by WebGPU.
///
/// See: https://www.w3.org/TR/WGSL/#types
///
/// todo: support packed types
pub trait WebgpuNativeType: Unit {
/// Name of the data type in WGSL.
const NAME: &'static str;
}

macro_rules! webgpu_type {
($RustTy:ty) => {
impl WebgpuNativeType for $RustTy {
const NAME: &'static str = stringify!($RustTy);
}
};
($RustTy:ty, $WgpuTy:expr) => {
impl WebgpuNativeType for $RustTy {
const NAME: &'static str = $WgpuTy;
}
};
}

/*
see:
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F16
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_F64
- https://docs.rs/wgpu/latest/wgpu/struct.Features.html#associatedconstant.SHADER_I16
*/
#[cfg(feature = "f16")]
webgpu_type!(half::f16, "f16");
webgpu_type!(f32);
// todo: only enable when f64 feature is enabled
#[cfg(feature = "f64")]
webgpu_type!(f64);

#[cfg(feature = "i16")]
webgpu_type!(i16);
webgpu_type!(i32);

webgpu_type!(u32);
webgpu_type!(bool);

pub(crate) trait HasGlslType {
const TYPE: &'static str;
}

impl HasGlslType for f32 {
const TYPE: &'static str = "float";
}

impl HasGlslType for f64 {
const TYPE: &'static str = "double";
}
16 changes: 16 additions & 0 deletions dfdx-core/src/tensor_ops/to_dtype/to_dtype.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
alias T = __SRC__;
alias U = __DST__;

@group(0) @binding(0)
var<storage, read> in: array<T>;

@group(0) @binding(1)
var<storage, read_write> out: array<U>;

@compute @workgroup_size(1, 1, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let i = global_id.x;
out[i] = U(in[i]);
}
99 changes: 96 additions & 3 deletions dfdx-core/src/tensor_ops/to_dtype/webgpu_kernel.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,102 @@
use crate::prelude::{Unit, Webgpu};
use crate::{
prelude::Storage,
tensor::webgpu::{Webgpu, WebgpuNativeType},
tensor_ops::utilities::webgpu_kernels::webgpu_params,
};
use num_traits::AsPrimitive;
use wgpu;

impl<E1: Unit, E2: Unit> super::ToDtypeKernel<E1, E2> for Webgpu {
/// kernel template
const KERNEL: &'static str = include_str!("./to_dtype.wgsl");

const LAYOUT_DESC: wgpu::BindGroupLayoutDescriptor = wgpu::BindGroupLayoutDescriptor {
label: Some("to-dtype"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
};

impl<E1: WebgpuNativeType + AsPrimitive<E2>, E2: WebgpuNativeType> super::ToDtypeKernel<E1, E2>
for Webgpu
{
fn forward<S: crate::prelude::Shape>(
inp: crate::prelude::Tensor<S, E1, Self>,
) -> Result<crate::prelude::Tensor<S, E2, Self>, crate::prelude::Error> {
todo!()
let module_name = std::format!("convert_{}_to_{}", E1::NAME, E2::NAME);
let label = Some(module_name.as_str());
let device = inp.device;

let layout = device.dev.create_bind_group_layout(&LAYOUT_DESC);
let shader_source: String = KERNEL
.replace("__SRC__", E1::NAME)
.replace("__DST__", E2::NAME);

// TODO: support WGSL shaders in device shader cache
let source = wgpu::ShaderSource::Wgsl(shader_source.into());
let shader_module = device
.dev
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(shader_name),
source,
});
let pipeline_layout = device
.dev
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: label.clone(),
bind_group_layouts: layouts,
// todo: these are useful and we should use them if the adapter supports them
push_constant_ranges: &push_constant_ranges,
});

let pipeline = device
.dev
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: label.clone(),
layout: Some(&pipeline_layout),
module: &shader_module,
entry_point: fn_name,
});

let numel = inp.shape.num_elements();
let shape = inp.shape;
let strides = shape.strides();
let output = unsafe { device.alloc_empty::<E2>(numel) }?;

let params: wgpu::BindGroup = webgpu_params!(device, pipeline; inp.data, output);

let _idx = device.submit_commands(label.clone(), |encoder| {
let (x, y, z) = *work_groups;
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: label.clone(),
..Default::default()
});
// TODO: should this be called before the pass, as the pass is created, or before submission?
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &params, &[]);
pass.dispatch_workgroups(numel as u32, 1, 1);
});

// note: no need to sync here, buffer can remain on the gpu until to_array or to_vec gets called,
// and those functions sync the device before mapping the buffer
Ok(device.build_tensor(shape, strides, output))
}
}
45 changes: 28 additions & 17 deletions dfdx-core/src/tensor_ops/utilities/webgpu_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,33 @@ use crate::{
use core::any::TypeId;
use std::{borrow::Cow, marker::PhantomData, sync::Arc, vec::Vec};

use wgpu::{
BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
};

/// Creates a [`BindGroup`] for a pipeline from a set of [`wgpu::BindingResource`]s.
macro_rules! webgpu_params {
($self:expr, $pipeline:expr; $($x:expr),+ $(,)? ) => {
{
let bindings = [$($x.as_entire_binding()),+];
let entries: Vec<_> = bindings
.into_iter()
.enumerate()
.map(|(i, binding)| wgpu::BindGroupEntry {
binding: i as u32,
resource: binding,
})
.collect();
$self.dev.create_bind_group(&::wgpu::BindGroupDescriptor {
label: None,
layout: &($pipeline).get_bind_group_layout(0),
entries: &entries
})
}
}
}
pub(crate) use webgpu_params;

pub(crate) trait UnaryOpWebgpuKernel<E> {
const DF_USES_FX: bool;
const HAS_CONST_DF: bool;
Expand Down Expand Up @@ -49,6 +76,7 @@ macro_rules! webgpu_unary {
}
};
}
pub(crate) use webgpu_unary;

/// Zero-sized marker type for forward pass TypeId
#[derive(Debug, Default)]
Expand All @@ -62,23 +90,6 @@ pub(crate) struct Backward<E: Dtype, K> {
_phantom: PhantomData<(E, K)>,
}

pub(crate) trait HasGlslType {
const TYPE: &'static str;
}

impl HasGlslType for f32 {
const TYPE: &'static str = "float";
}

impl HasGlslType for f64 {
const TYPE: &'static str = "double";
}

pub(crate) use webgpu_unary;
use wgpu::{
BindingType, BufferBindingType, ComputePipelineDescriptor, Device, PipelineLayout, ShaderStages,
};

impl<E: Dtype + HasGlslType, K: UnaryOpWebgpuKernel<E> + 'static> UnaryKernel<K, E> for Webgpu {
const BACKWARD_WITHOUT_INP: bool = K::DF_USES_FX;
const BACKWARD_WITHOUT_DATA: bool = K::HAS_CONST_DF;
Expand Down

0 comments on commit 4722a99

Please sign in to comment.