Skip to content

Commit

Permalink
Enable unit test for wasi-nn WinML backend.
Browse files Browse the repository at this point in the history
This test was disabled because GitHub Actions Windows Server image
doesn't have desktop experience included. But it looks like we can have
a standalone WinML binary downloaded from ONNX Runtime project.

Wasi-nn WinML backend and ONNX Runtime backend now share the same test
code as they accept the same input, and they are expected to produce the
same result.

This change also make wasi-nn WinML backend as a default feature.

prtest:full
  • Loading branch information
jianjunz committed Jun 6, 2024
1 parent 28ee78b commit 0fc6460
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 46 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,9 @@ jobs:
matrix:
feature: ["openvino", "onnx"]
os: ["ubuntu-latest", "windows-latest"]
include:
- os: windows-latest
feature: winml
name: Test wasi-nn (${{ matrix.feature }}, ${{ matrix.os }})
runs-on: ${{ matrix.os }}
needs: determine
Expand All @@ -696,6 +699,15 @@ jobs:
- uses: abrown/install-openvino-action@v8
if: runner.arch == 'X64'

# Install WinML for testing wasi-nn WinML backend. WinML is only available
# on Windows clients and Windows Server with desktop experience enabled.
# GitHub Actions Window Server image doesn't have desktop experience
# enabled, so we download the standalone library from ONNX Runtime project.
- uses: nuget/setup-nuget@v2
if: (matrix.os == 'windows-latest') && (matrix.feature == 'winml')
- run: nuget install Microsoft.AI.MachineLearning
if: (matrix.os == 'windows-latest') && (matrix.feature == 'winml')

# Install Rust targets.
- run: rustup target add wasm32-wasi

Expand Down
5 changes: 4 additions & 1 deletion crates/test-programs/src/bin/nn_image_classification_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ pub fn main() -> Result<()> {
.expect("the model file to be mapped to the fixture directory");
let graph =
GraphBuilder::new(GraphEncoding::Onnx, ExecutionTarget::CPU).build_from_bytes([&model])?;
let tensor = fs::read("fixture/tensor.bgr")
let tensor = fs::read("fixture/000000062808.rgb")
.expect("the tensor file to be mapped to the fixture directory");
let results = classify(graph, tensor)?;
let top_five = &sort_results(&results)[..5];
// 963 is meat loaf, meatloaf.
// https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/synset.txt#L963
assert_eq!(top_five[0].class_id(), 963);
println!("found results, sorted top 5: {:?}", top_five);
Ok(())
}
4 changes: 0 additions & 4 deletions crates/test-programs/src/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,8 @@ pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
/// placing the match probability for each class at the index for that class
/// (the probability of class `N` is stored at `probabilities[N]`).
pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
// It is unclear why the MobileNet output indices are "off by one" but the
// `.skip(1)` below seems necessary to get results that make sense (e.g. 763
// = "revolver" vs 762 = "restaurant").
let mut results: Vec<InferenceResult> = probabilities
.iter()
.skip(1)
.enumerate()
.map(|(c, p)| InferenceResult(c, *p))
.collect();
Expand Down
2 changes: 1 addition & 1 deletion crates/wasi-nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ wasi-common = { workspace = true, features = ["sync"] }
wasmtime = { workspace = true, features = ["cranelift"] }

[features]
default = ["openvino"]
default = ["openvino", "winml"]
# openvino is available on all platforms, it requires openvino installed.
openvino = ["dep:openvino"]
# onnx is available on all platforms.
Expand Down
55 changes: 17 additions & 38 deletions crates/wasi-nn/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@

#[allow(unused_imports)]
use anyhow::{anyhow, Context, Result};
use std::{env, fs, path::Path, path::PathBuf, process::Command, sync::Mutex};
use std::{
env, fs,
path::{Path, PathBuf},
process::Command,
sync::Mutex,
};

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
use windows::AI::MachineLearning::{LearningModelDevice, LearningModelDeviceKind};
Expand Down Expand Up @@ -50,13 +55,12 @@ pub fn check() -> Result<()> {
#[cfg(feature = "openvino")]
check_openvino_artifacts_are_available()?;

#[cfg(feature = "onnx")]
#[cfg(any(feature = "onnx", all(feature = "winml", target_os = "windows")))]
check_onnx_artifacts_are_available()?;

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
#[cfg(all(feature = "winml", target_os = "windows"))]
{
check_winml_is_available()?;
check_winml_artifacts_are_available()?;
}
Ok(())
}
Expand Down Expand Up @@ -108,7 +112,7 @@ fn check_openvino_artifacts_are_available() -> Result<()> {
Ok(())
}

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
#[cfg(all(feature = "winml", target_os = "windows"))]
fn check_winml_is_available() -> Result<()> {
match std::panic::catch_unwind(|| {
println!(
Expand All @@ -121,59 +125,34 @@ fn check_winml_is_available() -> Result<()> {
}
}

#[cfg(feature = "onnx")]
#[cfg(any(feature = "onnx", all(feature = "winml", target_os = "windows")))]
fn check_onnx_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();

const OPENVINO_BASE_URL: &str =
"https://github.com/intel/openvino-rs/raw/main/crates/openvino/tests/fixtures/mobilenet";
const ONNX_BASE_URL: &str =
"https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx?download=";
"https://github.com/onnx/models/raw/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/mobilenet/model/mobilenetv2-10.onnx?download=";

let artifacts_dir = artifacts_dir();
if !artifacts_dir.is_dir() {
fs::create_dir(&artifacts_dir)?;
}

for (from, to) in [
(
[OPENVINO_BASE_URL, "tensor-1x224x224x3-f32.bgr"].join("/"),
"tensor.bgr",
),
(ONNX_BASE_URL.to_string(), "model.onnx"),
] {
for (from, to) in [(ONNX_BASE_URL.to_string(), "model.onnx")] {
let local_path = artifacts_dir.join(to);
if !local_path.is_file() {
download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
} else {
println!("> using cached artifact: {}", local_path.display())
}
}
Ok(())
}

#[cfg(all(feature = "winml", target_arch = "x86_64", target_os = "windows"))]
fn check_winml_artifacts_are_available() -> Result<()> {
let _exclusively_retrieve_artifacts = ARTIFACTS.lock().unwrap();
let artifacts_dir = artifacts_dir();
if !artifacts_dir.is_dir() {
fs::create_dir(&artifacts_dir)?;
}
const MODEL_URL: &str="https://github.com/onnx/models/raw/5faef4c33eba0395177850e1e31c4a6a9e634c82/vision/classification/mobilenet/model/mobilenetv2-12.onnx";
for (from, to) in [(MODEL_URL, "model.onnx")] {
let local_path = artifacts_dir.join(to);
if !local_path.is_file() {
download(&from, &local_path).with_context(|| "unable to retrieve test artifact")?;
} else {
println!("> using cached artifact: {}", local_path.display())
}
}
// kitten.rgb is converted from https://github.com/microsoft/Windows-Machine-Learning/blob/master/SharedContent/media/kitten_224.png?raw=true.
let tensor_path = env::current_dir()?
// Copy image from source tree to artifact directory.
let image_path = env::current_dir()?
.join("tests")
.join("fixtures")
.join("kitten.rgb");
fs::copy(tensor_path, artifacts_dir.join("kitten.rgb"))?;
.join("000000062808.rgb");
let dest_path = artifacts_dir.join("000000062808.rgb");
fs::copy(&image_path, &dest_path)?;
Ok(())
}

Expand Down
4 changes: 2 additions & 2 deletions crates/wasi-nn/tests/all.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ fn nn_image_classification_named() {
#[cfg_attr(not(all(feature = "winml", target_os = "windows")), ignore)]
#[test]
fn nn_image_classification_winml() {
#[cfg(feature = "winml")]
#[cfg(all(feature = "winml", target_os = "windows"))]
{
let backend = Backend::from(backend::winml::WinMLBackend::default());
run(NN_IMAGE_CLASSIFICATION_WINML, backend, true).unwrap()
run(NN_IMAGE_CLASSIFICATION_ONNX, backend, true).unwrap()
}
}

Expand Down
Binary file added crates/wasi-nn/tests/fixtures/000000062808.rgb
Binary file not shown.
Binary file removed crates/wasi-nn/tests/fixtures/kitten.rgb
Binary file not shown.
9 changes: 9 additions & 0 deletions crates/wasi-nn/tests/fixtures/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
The original image of 000000062808.rgb is 000000062808.jpg from ImageNet
database. It processed by following Python code with
https://github.com/onnx/models/blob/bec48b6a70e5e9042c0badbaafefe4454e072d08/validated/vision/classification/imagenet_preprocess.py

```
image = mxnet.image.imread('000000062808.jpg')
image = preprocess_mxnet(image)
image.asnumpy().tofile('000000062808.rgb')
```

0 comments on commit 0fc6460

Please sign in to comment.