Add NPU support for wasi-nn WinML backend.
This change adds NPU (Neural Processing Unit) support to the wasi-nn
WinML backend. Since NPU support in DirectML is still in developer
preview, only a subset of learning models is supported.
jianjunz committed Jul 1, 2024
1 parent 8fc4186 commit fabbc56
Showing 3 changed files with 215 additions and 47 deletions.
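For orientation, the sketch below shows how a Wasm guest might exercise the new code path. It is illustrative only and not part of this commit: it assumes the `wasi-nn` guest crate's raw bindings (constants like `GRAPH_ENCODING_ONNX`, `EXECUTION_TARGET_TPU`, and `TENSOR_TYPE_F32` come from that crate, not from this diff) and a host that routes ONNX models to the WinML backend on Windows.

```rust
// Hypothetical guest-side usage; names are from the `wasi-nn` crate, not this diff.
fn run_on_npu(model: &[u8], input_f32_bytes: &[u8]) {
    unsafe {
        // EXECUTION_TARGET_TPU is wasi-nn's generic accelerator target; after
        // this change the WinML backend maps it to an NPU LearningModelDevice.
        let graph = wasi_nn::load(
            &[model],
            wasi_nn::GRAPH_ENCODING_ONNX,
            wasi_nn::EXECUTION_TARGET_TPU,
        )
        .expect("failed to load model on the NPU");
        let ctx = wasi_nn::init_execution_context(graph).expect("failed to create context");
        // FP32 input shown here; FP16 and I64 tensors are also accepted after this change.
        let tensor = wasi_nn::Tensor {
            dimensions: &[1, 3, 224, 224],
            type_: wasi_nn::TENSOR_TYPE_F32,
            data: input_f32_bytes,
        };
        wasi_nn::set_input(ctx, 0, tensor).expect("failed to set input");
        wasi_nn::compute(ctx).expect("failed to run inference");
        let mut output = vec![0f32; 1000];
        wasi_nn::get_output(
            ctx,
            0,
            output.as_mut_ptr() as *mut u8,
            (output.len() * std::mem::size_of::<f32>()) as u32,
        )
        .expect("failed to read output");
    }
}
```

If no suitable adapter is found at load time, the backend returns the "NPU is not available on this device." error shown in the diff below.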
24 changes: 24 additions & 0 deletions Cargo.lock

Cargo.lock is a generated file, so its diff is not rendered by default.

15 changes: 14 additions & 1 deletion crates/wasi-nn/Cargo.toml
@@ -40,7 +40,20 @@ ort = { version = "2.0.0-rc.2", default-features = false, features = [
 
 [target.'cfg(windows)'.dependencies.windows]
 version = "0.52"
-features = ["AI_MachineLearning", "Storage_Streams", "Foundation_Collections"]
+features = [
+    "AI_MachineLearning",
+    "Storage_Streams",
+    "Foundation_Collections",
+    # For Int64 input support.
+    "implement",
+    # The following six features are needed to create a LearningModelDevice from an NPU.
+    "Win32_Foundation",
+    "Win32_Graphics_Direct3D",
+    "Win32_Graphics_Direct3D12",
+    "Win32_Graphics_Dxgi",
+    "Win32_Graphics_DXCore",
+    "Win32_System_WinRT_ML",
+]
 optional = true
 
 [build-dependencies]
223 changes: 177 additions & 46 deletions crates/wasi-nn/src/backend/winml.rs
@@ -14,14 +14,26 @@ use crate::wit::{ExecutionTarget, GraphEncoding, Tensor, TensorType};
 use crate::{ExecutionContext, Graph};
 use std::{fs::File, io::Read, mem::size_of, path::Path};
 use windows::core::{ComInterface, HSTRING};
-use windows::Foundation::Collections::IVectorView;
+use windows::Foundation::Collections::{IVectorView, IIterable};
 use windows::Storage::Streams::{
     DataWriter, InMemoryRandomAccessStream, RandomAccessStreamReference,
 };
+use windows::Win32::Graphics::DXCore::{
+    DXCoreCreateAdapterFactory, IDXCoreAdapter, IDXCoreAdapterFactory, IDXCoreAdapterList,
+    DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS,
+};
+use windows::Win32::Graphics::{
+    Direct3D::D3D_FEATURE_LEVEL_1_0_CORE,
+    Direct3D12::{
+        D3D12CreateDevice, ID3D12CommandQueue, ID3D12Device, D3D12_COMMAND_LIST_TYPE_COMPUTE,
+        D3D12_COMMAND_QUEUE_DESC, D3D12_COMMAND_QUEUE_FLAG_NONE,
+    },
+};
+use windows::Win32::System::WinRT::ML::ILearningModelDeviceFactoryNative;
 use windows::AI::MachineLearning::{
     ILearningModelFeatureDescriptor, LearningModel, LearningModelBinding, LearningModelDevice,
     LearningModelDeviceKind, LearningModelEvaluationResult, LearningModelSession,
-    TensorFeatureDescriptor, TensorFloat,
+    TensorFeatureDescriptor, TensorFloat, TensorFloat16Bit, TensorInt64Bit, TensorKind,
 };
 
 #[derive(Default)]
@@ -45,12 +57,66 @@ impl BackendInner for WinMLBackend {
         let model = LearningModel::LoadFromStream(&RandomAccessStreamReference::CreateFromStream(
             &model_stream,
         )?)?;
-        let device_kind = match target {
-            ExecutionTarget::Cpu => LearningModelDeviceKind::Cpu,
-            ExecutionTarget::Gpu => LearningModelDeviceKind::DirectX,
-            ExecutionTarget::Tpu => unimplemented!(),
+        let device = match target {
+            ExecutionTarget::Cpu => LearningModelDevice::Create(LearningModelDeviceKind::Cpu),
+            ExecutionTarget::Gpu => LearningModelDevice::Create(LearningModelDeviceKind::DirectX),
+            ExecutionTarget::Tpu => unsafe {
+                // Enumerate adapters with DXCore APIs so MCDM (Microsoft Compute Driver Model) devices can be found.
+                let dx_adapter_factory: IDXCoreAdapterFactory = DXCoreCreateAdapterFactory()?;
+                let adapter_list =
+                    dx_adapter_factory.CreateAdapterList::<IDXCoreAdapterList>(&[
+                        DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
+                    ])?;
+                let mut selected_device: Option<IDXCoreAdapter> = None;
+                for i in 0..adapter_list.GetAdapterCount() {
+                    let adapter = adapter_list.GetAdapter::<IDXCoreAdapter>(i)?;
+                    // Select a compute-only device. DXCORE_ADAPTER_ATTRIBUTE_D3D12_GENERIC_ML looks more suitable here, but it's only defined in the DirectX headers.
+                    if adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE)
+                        && !adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)
+                    {
+                        selected_device = Some(adapter);
+                        break;
+                    }
+                }
+                if selected_device.is_none() {
+                    return Err(BackendError::BackendAccess(anyhow::Error::msg(
+                        "NPU is not available on this device.",
+                    )));
+                }
+
+                let mut d3d12_device: Option<ID3D12Device> = None;
+                D3D12CreateDevice(
+                    &selected_device.unwrap(),
+                    D3D_FEATURE_LEVEL_1_0_CORE,
+                    &mut d3d12_device,
+                )?;
+                if d3d12_device.is_none() {
+                    return Err(BackendError::BackendAccess(anyhow::Error::msg(
+                        "Failed to create D3D12 device.",
+                    )));
+                }
+                let d3d12_command_queue_desc: D3D12_COMMAND_QUEUE_DESC = D3D12_COMMAND_QUEUE_DESC {
+                    Type: D3D12_COMMAND_LIST_TYPE_COMPUTE,
+                    Flags: D3D12_COMMAND_QUEUE_FLAG_NONE,
+                    NodeMask: 0,
+                    Priority: 0,
+                };
+                let d3d12_command_queue = d3d12_device
+                    .unwrap()
+                    .CreateCommandQueue::<ID3D12CommandQueue>(&d3d12_command_queue_desc)?;
+                let factory = windows::core::factory::<
+                    LearningModelDevice,
+                    ILearningModelDeviceFactoryNative,
+                >()?;
+                factory
+                    .CreateFromD3D12CommandQueue(&d3d12_command_queue)?
+                    .cast::<LearningModelDevice>()
+            },
         };
+        let graph = WinMLGraph {
+            model,
+            device: device?,
+        };
-        let graph = WinMLGraph { model, device_kind };
 
         let box_: Box<dyn BackendGraph> = Box::new(graph);
         Ok(box_.into())
@@ -74,16 +140,16 @@ impl BackendFromDir for WinMLBackend {
 
 struct WinMLGraph {
     model: LearningModel,
-    device_kind: LearningModelDeviceKind,
+    device: LearningModelDevice,
 }
 
 unsafe impl Send for WinMLGraph {}
 unsafe impl Sync for WinMLGraph {}
 
 impl BackendGraph for WinMLGraph {
     fn init_execution_context(&self) -> Result<ExecutionContext, BackendError> {
-        let device = LearningModelDevice::Create(self.device_kind.clone())?;
-        let session = LearningModelSession::CreateFromModelOnDevice(&self.model, &device)?;
+        let session =
+            LearningModelSession::CreateFromModelOnDevice(&self.model, &self.device).unwrap();
         let box_: Box<dyn BackendExecutionContext> = Box::new(WinMLExecutionContext::new(session));
         Ok(box_.into())
     }
@@ -136,32 +202,58 @@ impl WinMLExecutionContext {
 
 impl BackendExecutionContext for WinMLExecutionContext {
     fn set_input(&mut self, id: Id, tensor: &Tensor) -> Result<(), BackendError> {
         // TODO: Clear previous bindings when needed.
 
         let input_features = self.session.Model()?.InputFeatures()?;
         let index = self.find(id, &input_features)?;
         let input = input_features.GetAt(index)?;
 
-        // TODO: Support other tensor types. Only FP32 is supported right now.
+        // TODO: Support other tensor types. Only FP16, FP32 and I64 are
+        // supported right now.
         match tensor.ty {
-            crate::wit::types::TensorType::Fp32 => {}
-            _ => unimplemented!(),
-        }
+            crate::wit::types::TensorType::Fp16 => unsafe {
+                let data = std::slice::from_raw_parts(
+                    tensor.data.as_ptr() as *const f32,
+                    tensor.data.len() / size_of::<f32>(),
+                );
+
-        // TODO: this is quite unsafe and probably incorrect--will the slice
-        // still be around by the time the binding is used?!
-        let data = unsafe {
-            std::slice::from_raw_parts(
-                tensor.data.as_ptr() as *const f32,
-                tensor.data.len() / size_of::<f32>(),
-            )
-        };
+                // TODO: this is quite unsafe and probably incorrect--will the
+                // slice still be around by the time the binding is used?!
+                self.binding.Bind(
+                    &input.Name()?,
+                    &TensorFloat16Bit::CreateFromArray(
+                        &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
+                        data,
+                    )?,
+                )?;
+            },
+            crate::wit::types::TensorType::Fp32 => unsafe {
+                let data = std::slice::from_raw_parts(
+                    tensor.data.as_ptr() as *const f32,
+                    tensor.data.len() / size_of::<f32>(),
+                );
+
-        self.binding.Bind(
-            &input.Name()?,
-            &TensorFloat::CreateFromArray(
-                &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
-                data,
-            )?,
-        )?;
+                self.binding.Bind(
+                    &input.Name()?,
+                    &TensorFloat::CreateFromArray(
+                        &input.cast::<TensorFeatureDescriptor>()?.Shape()?,
+                        data,
+                    )?,
+                )?;
+            },
+            crate::wit::types::TensorType::I64 => unsafe {
+                let data = std::slice::from_raw_parts(
+                    tensor.data.as_ptr() as *const i64,
+                    tensor.data.len() / size_of::<i64>(),
+                );
+                let dim: Vec<i64> = tensor.dimensions.iter().map(|&x| x as i64).collect();
+                let shape: IIterable<i64> = IIterable::<i64>::try_from(dim)?;
+                let tensor = TensorInt64Bit::CreateFromArray(&shape, data)?;
+
+                self.binding.Bind(&input.Name()?, &tensor)?;
+            },
+            _ => unimplemented!(),
+        }
 
         Ok(())
     }
@@ -175,23 +267,62 @@ impl BackendExecutionContext for WinMLExecutionContext {
         if let Some(result) = &self.result {
             let output_features = self.session.Model()?.OutputFeatures()?;
             let index = self.find(id, &output_features)?;
-            let output = output_features.GetAt(index)?;
-            // TODO: this only handles FP32!
-            let tensor = result
-                .Outputs()?
-                .Lookup(&output.Name()?)?
-                .cast::<TensorFloat>()?;
-            let dimensions = dimensions_as_u32(&tensor.Shape()?)?;
-            let view = tensor.GetAsVectorView()?;
-            let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
-            for f in view.into_iter() {
-                data.extend(f.to_le_bytes());
-            }
-            Ok(Tensor {
-                ty: TensorType::Fp32,
-                dimensions,
-                data,
-            })
+            let output_feature = output_features.GetAt(index)?;
+            let tensor_kind = match output_feature.Kind()? {
+                windows::AI::MachineLearning::LearningModelFeatureKind::Tensor => output_feature
+                    .cast::<TensorFeatureDescriptor>()?
+                    .TensorKind()?,
+                _ => unimplemented!(),
+            };
+            // TODO: this only handles FP16, FP32 and I64!
+            let output_inspectable = result.Outputs()?.Lookup(&output_feature.Name()?)?;
+            let tensor = match tensor_kind {
+                TensorKind::Float16 => {
+                    let output_tensor = output_inspectable.cast::<TensorFloat16Bit>()?;
+                    let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
+                    let view = output_tensor.GetAsVectorView()?;
+                    // TODO: Move to f16 when it's available in stable.
+                    let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
+                    for f in view.into_iter() {
+                        data.extend(f.to_le_bytes());
+                    }
+                    Tensor {
+                        ty: TensorType::Fp16,
+                        dimensions,
+                        data,
+                    }
+                }
+                TensorKind::Float => {
+                    let output_tensor = output_inspectable.cast::<TensorFloat>()?;
+                    let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
+                    let view = output_tensor.GetAsVectorView()?;
+                    let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<f32>());
+                    for f in view.into_iter() {
+                        data.extend(f.to_le_bytes());
+                    }
+                    Tensor {
+                        ty: TensorType::Fp32,
+                        dimensions,
+                        data,
+                    }
+                }
+                TensorKind::Int64 => {
+                    let output_tensor = output_inspectable.cast::<TensorInt64Bit>()?;
+                    let dimensions = dimensions_as_u32(&output_tensor.Shape()?)?;
+                    let view = output_tensor.GetAsVectorView()?;
+                    let mut data = Vec::with_capacity(view.Size()? as usize * size_of::<i64>());
+                    for f in view.into_iter() {
+                        data.extend(f.to_le_bytes());
+                    }
+                    Tensor {
+                        ty: TensorType::I64,
+                        dimensions,
+                        data,
+                    }
+                }
+                _ => unimplemented!(),
+            };
+            Ok(tensor)
         } else {
             return Err(BackendError::BackendAccess(anyhow::Error::msg(
                 "Output is not ready.",
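The core of the `Tpu` arm above is its adapter-selection heuristic. The sketch below condenses it into a standalone function: enumerate DXCore adapters and take the first one that supports D3D12 core compute but not graphics, which is how this change identifies an MCDM device such as an NPU. The name `find_npu_adapter` is ours, not from the commit; the windows-rs calls and required crate features are the same as in the diff, with error handling collapsed into `windows::core::Result`.

```rust
use windows::Win32::Graphics::DXCore::{
    DXCoreCreateAdapterFactory, IDXCoreAdapter, IDXCoreAdapterFactory, IDXCoreAdapterList,
    DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE, DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS,
};

/// Returns the first compute-only (MCDM) adapter, if any.
fn find_npu_adapter() -> windows::core::Result<Option<IDXCoreAdapter>> {
    unsafe {
        let factory: IDXCoreAdapterFactory = DXCoreCreateAdapterFactory()?;
        // Only enumerate adapters that can run D3D12 core compute.
        let list = factory.CreateAdapterList::<IDXCoreAdapterList>(&[
            DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE,
        ])?;
        for i in 0..list.GetAdapterCount() {
            let adapter = list.GetAdapter::<IDXCoreAdapter>(i)?;
            // Compute-capable but not graphics-capable: treat it as an NPU.
            if adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE)
                && !adapter.IsAttributeSupported(&DXCORE_ADAPTER_ATTRIBUTE_D3D12_GRAPHICS)
            {
                return Ok(Some(adapter));
            }
        }
        Ok(None)
    }
}
```

From the selected adapter, the backend then creates a D3D12 device at `D3D_FEATURE_LEVEL_1_0_CORE`, a compute command queue, and finally a `LearningModelDevice` via `ILearningModelDeviceFactoryNative::CreateFromD3D12CommandQueue`, as shown in the hunk above.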
