Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add String data type support #58

Merged
merged 4 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
onnxruntime.git
Cargo.lock
**/synset.txt

/.idea
39 changes: 39 additions & 0 deletions onnxruntime/examples/print_structure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//! Display the input and output structure of an ONNX model.
use onnxruntime::environment;
use std::error::Error;

fn main() -> Result<(), Box<dyn Error>> {
    // Path to the .onnx model is the first CLI argument; `nth(1)` skips
    // argv[0] (the program name). Clippy flags `.skip(1).next()` as
    // `iter_skip_next` — `nth(1)` is the idiomatic equivalent.
    let path = std::env::args()
        .nth(1)
        .expect("Must provide an .onnx file as the first arg");

    // The environment owns onnxruntime's global logging/configuration state
    // shared by all sessions created from it.
    let environment = environment::Environment::builder()
        .with_name("onnx metadata")
        .with_log_level(onnxruntime::LoggingLevel::Verbose)
        .build()?;

    let session = environment
        .new_session_builder()?
        .with_optimization_level(onnxruntime::GraphOptimizationLevel::Basic)?
        .with_model_from_file(path)?;

    // Dump each model input: name, element type and (possibly symbolic) dimensions.
    println!("Inputs:");
    for (index, input) in session.inputs.iter().enumerate() {
        println!(
            " {}:\n name = {}\n type = {:?}\n dimensions = {:?}",
            index, input.name, input.input_type, input.dimensions
        );
    }

    // Dump each model output likewise.
    println!("Outputs:");
    for (index, output) in session.outputs.iter().enumerate() {
        println!(
            " {}:\n name = {}\n type = {:?}\n dimensions = {:?}",
            index, output.name, output.output_type, output.dimensions
        );
    }

    Ok(())
}
13 changes: 13 additions & 0 deletions onnxruntime/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ pub enum OrtError {
/// Error occurred when creating CPU memory information
#[error("Failed to get dimensions: {0}")]
CreateCpuMemoryInfo(OrtApiError),
/// Error occurred when creating ONNX tensor
#[error("Failed to create tensor: {0}")]
CreateTensor(OrtApiError),
/// Error occurred when creating ONNX tensor with specific data
#[error("Failed to create tensor with data: {0}")]
CreateTensorWithData(OrtApiError),
/// Error occurred when filling a tensor with string data
#[error("Failed to fill string tensor: {0}")]
FillStringTensor(OrtApiError),
/// Error occurred when checking if ONNX tensor was properly initialized
#[error("Failed to check if tensor: {0}")]
IsTensor(OrtApiError),
Expand Down Expand Up @@ -184,3 +190,10 @@ pub(crate) fn status_to_result(
let status_wrapper: OrtStatusWrapper = status.into();
status_wrapper.into()
}

/// A wrapper around a function on OrtApi that maps the status code into [OrtApiError]
pub(crate) unsafe fn call_ort<F: FnMut(sys::OrtApi) -> *const sys::OrtStatus>(
mut block: F,
) -> std::result::Result<(), OrtApiError> {
marshallpierce marked this conversation as resolved.
Show resolved Hide resolved
status_to_result(block(g_ort()))
marshallpierce marked this conversation as resolved.
Show resolved Hide resolved
}
55 changes: 46 additions & 9 deletions onnxruntime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ pub enum TensorElementDataType {
Int32 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
/// Signed 64-bit int, equivalent to Rust's `i64`
Int64 = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
// /// String, equivalent to Rust's `String`
// String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
/// String, equivalent to Rust's `String`
String = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
// /// Boolean, equivalent to Rust's `bool`
// Bool = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
// /// 16-bit floating point, equivalent to Rust's `f16`
Expand Down Expand Up @@ -374,9 +374,7 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
Int16 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
Int32 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
Int64 => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
// String => {
// sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING
// }
String => sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
// Bool => {
// sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
// }
Expand All @@ -402,15 +400,22 @@ impl Into<sys::ONNXTensorElementDataType> for TensorElementDataType {
/// Trait used to map Rust types (for example `f32`) to ONNX types (for example `Float`)
pub trait TypeToTensorElementDataType {
/// Return the ONNX type for a Rust type
fn tensor_element_data_type() -> sys::ONNXTensorElementDataType;
fn tensor_element_data_type() -> TensorElementDataType;

/// If the type is `String`, returns `Some` with utf8 contents, else `None`.
fn try_utf8_bytes(&self) -> Option<&[u8]>;
}

macro_rules! impl_type_trait {
($type_:ty, $variant:ident) => {
impl TypeToTensorElementDataType for $type_ {
fn tensor_element_data_type() -> sys::ONNXTensorElementDataType {
fn tensor_element_data_type() -> TensorElementDataType {
// unsafe { std::mem::transmute(TensorElementDataType::$variant) }
TensorElementDataType::$variant.into()
TensorElementDataType::$variant
}

fn try_utf8_bytes(&self) -> Option<&[u8]> {
None
}
}
};
Expand All @@ -423,7 +428,6 @@ impl_type_trait!(u16, Uint16);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
// impl_type_trait!(String, String);
// impl_type_trait!(bool, Bool);
// impl_type_trait!(f16, Float16);
impl_type_trait!(f64, Double);
Expand All @@ -433,6 +437,39 @@ impl_type_trait!(u64, Uint64);
// impl_type_trait!(, Complex128);
// impl_type_trait!(, Bfloat16);

/// Bridges common Rust string types into ONNX string tensor data.
///
/// Both `String` and `&str` should be usable as [TensorElementDataType::String] elements.
/// A blanket implementation over everything implementing `AsRef<str>` is not possible,
/// however: it would collide with the [TypeToTensorElementDataType] implementations for
/// the primitive numeric types (which could conceivably gain an `AsRef<str>`
/// implementation at some point in the future).
pub trait Utf8Data {
    /// The contents of this value as UTF-8 encoded bytes.
    fn utf8_bytes(&self) -> &[u8];
}

impl Utf8Data for &str {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_bytes()
    }
}

impl Utf8Data for String {
    fn utf8_bytes(&self) -> &[u8] {
        self.as_str().as_bytes()
    }
}

/// Blanket implementation: any type that can expose its contents as utf8 bytes
/// is usable as ONNX string tensor element data.
impl<T: Utf8Data> TypeToTensorElementDataType for T {
fn tensor_element_data_type() -> TensorElementDataType {
// Every Utf8Data type maps to the ONNX string element type.
TensorElementDataType::String
}

fn try_utf8_bytes(&self) -> Option<&[u8]> {
// String-typed data always has utf8 contents available.
Some(self.utf8_bytes())
}
}

/// Allocator type
#[derive(Debug, Clone)]
#[repr(i32)]
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,9 @@ impl<'a> Session<'a> {
// The C API expects pointers for the arrays (pointers to C-arrays)
let input_ort_tensors: Vec<OrtTensor<TIn, D>> = input_arrays
.into_iter()
.map(|input_array| OrtTensor::from_array(&self.memory_info, input_array))
.map(|input_array| {
OrtTensor::from_array(&self.memory_info, self.allocator_ptr, input_array)
})
.collect::<Result<Vec<OrtTensor<TIn, D>>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors
.iter()
Expand Down
152 changes: 126 additions & 26 deletions onnxruntime/src/tensor/ort_tensor.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
//! Module containing tensor with memory owned by Rust

use std::{fmt::Debug, ops::Deref};
use std::{ffi, fmt::Debug, ops::Deref};

use ndarray::Array;
use tracing::{debug, error};

use onnxruntime_sys as sys;

use crate::{
error::status_to_result, g_ort, memory::MemoryInfo, tensor::ndarray_tensor::NdArrayTensor,
OrtError, Result, TypeToTensorElementDataType,
error::call_ort, error::status_to_result, g_ort, memory::MemoryInfo,
tensor::ndarray_tensor::NdArrayTensor, OrtError, Result, TensorElementDataType,
TypeToTensorElementDataType,
};

/// Owned tensor, backed by an [`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)
Expand Down Expand Up @@ -37,38 +38,103 @@ where
{
pub(crate) fn from_array<'m>(
memory_info: &'m MemoryInfo,
allocator_ptr: *mut sys::OrtAllocator,
mut array: Array<T, D>,
) -> Result<OrtTensor<'t, T, D>>
where
'm: 't, // 'm outlives 't
{
// where onnxruntime will write the tensor data to
let mut tensor_ptr: *mut sys::OrtValue = std::ptr::null_mut();
let tensor_ptr_ptr: *mut *mut sys::OrtValue = &mut tensor_ptr;
let tensor_values_ptr: *mut std::ffi::c_void = array.as_mut_ptr() as *mut std::ffi::c_void;
assert_ne!(tensor_values_ptr, std::ptr::null_mut());

let shape: Vec<i64> = array.shape().iter().map(|d: &usize| *d as i64).collect();
let shape_ptr: *const i64 = shape.as_ptr();
let shape_len = array.shape().len() as u64;

let status = unsafe {
g_ort().CreateTensorWithDataAsOrtValue.unwrap()(
memory_info.ptr,
tensor_values_ptr,
(array.len() * std::mem::size_of::<T>()) as u64,
shape_ptr,
shape_len,
T::tensor_element_data_type(),
tensor_ptr_ptr,
)
};
status_to_result(status).map_err(OrtError::CreateTensorWithData)?;
assert_ne!(tensor_ptr, std::ptr::null_mut());
match T::tensor_element_data_type() {
TensorElementDataType::Float
| TensorElementDataType::Uint8
| TensorElementDataType::Int8
| TensorElementDataType::Uint16
| TensorElementDataType::Int16
| TensorElementDataType::Int32
| TensorElementDataType::Int64
| TensorElementDataType::Double
| TensorElementDataType::Uint32
| TensorElementDataType::Uint64 => {
nbigaouette marked this conversation as resolved.
Show resolved Hide resolved
// primitive data is already suitably laid out in memory; provide it to
// onnxruntime as is
let tensor_values_ptr: *mut std::ffi::c_void =
array.as_mut_ptr() as *mut std::ffi::c_void;
assert_ne!(tensor_values_ptr, std::ptr::null_mut());

unsafe {
call_ort(|ort| {
ort.CreateTensorWithDataAsOrtValue.unwrap()(
memory_info.ptr,
tensor_values_ptr,
(array.len() * std::mem::size_of::<T>()) as u64,
shape_ptr,
shape_len,
T::tensor_element_data_type().into(),
tensor_ptr_ptr,
)
})
}
.map_err(OrtError::CreateTensorWithData)?;
assert_ne!(tensor_ptr, std::ptr::null_mut());

let mut is_tensor = 0;
let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
status_to_result(status).map_err(OrtError::IsTensor)?;
}
TensorElementDataType::String => {
// create tensor without data -- data is filled in later
unsafe {
call_ort(|ort| {
ort.CreateTensorAsOrtValue.unwrap()(
allocator_ptr,
shape_ptr,
shape_len,
T::tensor_element_data_type().into(),
tensor_ptr_ptr,
)
})
}
.map_err(OrtError::CreateTensor)?;

let mut is_tensor = 0;
let status = unsafe { g_ort().IsTensor.unwrap()(tensor_ptr, &mut is_tensor) };
status_to_result(status).map_err(OrtError::IsTensor)?;
assert_eq!(is_tensor, 1);
// create null-terminated copies of each string, as per `FillStringTensor` docs
let null_terminated_copies: Vec<ffi::CString> = array
.iter()
.map(|elt| {
let slice = elt
.try_utf8_bytes()
marshallpierce marked this conversation as resolved.
Show resolved Hide resolved
.expect("String data type must provide utf8 bytes");
ffi::CString::new(slice)
})
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(OrtError::CStringNulError)?;

let string_pointers = null_terminated_copies
.iter()
.map(|cstring| cstring.as_ptr())
.collect::<Vec<_>>();

unsafe {
call_ort(|ort| {
ort.FillStringTensor.unwrap()(
tensor_ptr,
string_pointers.as_ptr(),
string_pointers.len() as u64,
)
})
}
.map_err(OrtError::FillStringTensor)?;
}
}

assert_ne!(tensor_ptr, std::ptr::null_mut());

Ok(OrtTensor {
c_ptr: tensor_ptr,
Expand Down Expand Up @@ -129,13 +195,14 @@ mod tests {
use super::*;
use crate::{AllocatorType, MemType};
use ndarray::{arr0, arr1, arr2, arr3};
use std::ptr;
use test_env_log::test;

#[test]
fn orttensor_from_array_0d_i32() {
    // A zero-dimensional (scalar) array must produce a tensor with an empty shape.
    let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap();
    let scalar = arr0::<i32>(123);
    let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), scalar).unwrap();
    let no_dims: &[usize] = &[];
    assert_eq!(tensor.shape(), no_dims);
}
}
Expand All @@ -144,7 +211,7 @@ mod tests {
fn orttensor_from_array_1d_i32() {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap();
let array = arr1(&[1_i32, 2, 3, 4, 5, 6]);
let tensor = OrtTensor::from_array(&memory_info, array).unwrap();
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap();
let expected_shape: &[usize] = &[6];
assert_eq!(tensor.shape(), expected_shape);
}
Expand All @@ -153,7 +220,7 @@ mod tests {
fn orttensor_from_array_2d_i32() {
let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap();
let array = arr2(&[[1_i32, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]);
let tensor = OrtTensor::from_array(&memory_info, array).unwrap();
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap();
assert_eq!(tensor.shape(), &[2, 6]);
}

Expand All @@ -165,7 +232,40 @@ mod tests {
[[13, 14, 15, 16, 17, 18], [19, 20, 21, 22, 23, 24]],
[[25, 26, 27, 28, 29, 30], [31, 32, 33, 34, 35, 36]],
]);
let tensor = OrtTensor::from_array(&memory_info, array).unwrap();
let tensor = OrtTensor::from_array(&memory_info, ptr::null_mut(), array).unwrap();
assert_eq!(tensor.shape(), &[3, 2, 6]);
}

#[test]
fn orttensor_from_array_1d_string() {
    // Unlike primitive tensors, string tensors need a real allocator because
    // the string data is copied into onnxruntime-owned storage.
    let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap();
    let strings = arr1(&[
        "foo".to_string(),
        "bar".to_string(),
        "baz".to_string(),
    ]);
    let tensor = OrtTensor::from_array(&memory_info, ort_default_allocator(), strings).unwrap();
    assert_eq!(tensor.shape(), &[3]);
}

#[test]
fn orttensor_from_array_3d_string() {
    // `&str` (not only `String`) must also be accepted as string tensor data.
    // Note: scraped review-thread lines that had been embedded inside the
    // array literal (breaking the code) are removed here.
    let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default).unwrap();
    let array = arr3(&[
        [["1", "2", "3"], ["4", "5", "6"]],
        [["7", "8", "9"], ["10", "11", "12"]],
    ]);
    let tensor = OrtTensor::from_array(&memory_info, ort_default_allocator(), array).unwrap();
    assert_eq!(tensor.shape(), &[2, 2, 3]);
}

fn ort_default_allocator() -> *mut sys::OrtAllocator {
    let mut allocator: *mut sys::OrtAllocator = std::ptr::null_mut();
    // GetAllocatorWithDefaultOptions hands back a shared non-arena allocator
    // owned by onnxruntime itself, so the caller must not deallocate it.
    let status =
        unsafe { call_ort(|ort| ort.GetAllocatorWithDefaultOptions.unwrap()(&mut allocator)) };
    status.unwrap();
    allocator
}
}