Skip to content

Commit

Permalink
utils
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Aug 21, 2023
1 parent 632e9e1 commit b712886
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 33 deletions.
43 changes: 11 additions & 32 deletions onnx/defs/math/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1073,7 +1073,7 @@ ONNX_OPERATOR_SET_SCHEMA(
Softmax,
13,
OpSchema()
.FillUsing(SoftmaxFamilyDocGenerator(
.FillUsing(defs::math::utils::SoftmaxFamilyDocGenerator(
"Softmax",
"normalized exponential",
"Softmax(input, axis) = Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1) "))
Expand Down Expand Up @@ -1119,7 +1119,7 @@ ONNX_OPERATOR_SET_SCHEMA(
LogSoftmax,
13,
OpSchema()
.FillUsing(SoftmaxFamilyDocGenerator(
.FillUsing(defs::math::utils::SoftmaxFamilyDocGenerator(
"LogSoftmax",
"log of softmax",
"LogSoftmax(input, axis) = Log(Softmax(input, axis=axis))"))
Expand Down Expand Up @@ -1163,7 +1163,7 @@ ONNX_OPERATOR_SET_SCHEMA(
ONNX_OPERATOR_SET_SCHEMA(
Hardmax,
13,
OpSchema().FillUsing(SoftmaxFamilyDocGenerator(
OpSchema().FillUsing(defs::math::utils::SoftmaxFamilyDocGenerator(
"Hardmax",
"hardmax",
"Hardmax(element in input, axis) = 1 if the element is the first maximum value along the specified axis, 0 otherwise")));
Expand Down Expand Up @@ -2921,27 +2921,6 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}));

template <typename T>
static T get_scalar_value_from_tensor(const ONNX_NAMESPACE::TensorProto* t) {
if (t == nullptr) {
return T{};
}

auto data_type = t->data_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::FLOAT:
return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
case ONNX_NAMESPACE::TensorProto::DOUBLE:
return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
case ONNX_NAMESPACE::TensorProto::INT32:
return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
case ONNX_NAMESPACE::TensorProto::INT64:
return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
default:
fail_shape_inference("Unsupported input data type of ", data_type);
}
}

static const char* DFT_ver20_doc = R"DOC(Computes the discrete Fourier transform of the input.)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Expand Down Expand Up @@ -3045,7 +3024,7 @@ ONNX_OPERATOR_SET_SCHEMA(
if (axis_tensor->dims_size() != 1) {
fail_shape_inference("axis input must be a scalar.");
}
const int64_t axis = get_scalar_value_from_tensor<int64_t>(axis_tensor);
const int64_t axis = defs::math::utils::GetScalarValueFromTensor<int64_t>(axis_tensor);
const int64_t rank = input_shape.dim_size();

if (!(-rank <= axis && axis < rank)) {
Expand All @@ -3067,7 +3046,7 @@ ONNX_OPERATOR_SET_SCHEMA(
if (dft_length->dims_size() != 0) {
fail_shape_inference("dft_length input must be a scalar.");
}
auto dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
auto dft_length_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(dft_length);
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value);
}

Expand Down Expand Up @@ -3151,7 +3130,7 @@ Generates a {name} window as described in the paper https://ieeexplore.ieee.org/
fail_shape_inference("size input must be a scalar.");
}

auto size_value = get_scalar_value_from_tensor<int64_t>(size);
auto size_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(size);
if (size_value <= 0) {
fail_shape_inference("size input must be greater than 0.");
}
Expand Down Expand Up @@ -3381,12 +3360,12 @@ ONNX_OPERATOR_SET_SCHEMA(
if (num_mel_bins->dims_size() != 0) {
fail_shape_inference("num_mel_bins input must be scalar.");
}
num_mel_bins_value = get_scalar_value_from_tensor<int64_t>(num_mel_bins);
num_mel_bins_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(num_mel_bins);

if (dft_length->dims_size() != 0) {
fail_shape_inference("dft_length input must be scalar.");
}
dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
dft_length_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(dft_length);

if (num_mel_bins_value > 0 && dft_length_value > 0) {
ONNX_NAMESPACE::TensorShapeProto result_shape;
Expand Down Expand Up @@ -3498,7 +3477,7 @@ ONNX_OPERATOR_SET_SCHEMA(
if (nullptr == frame_step) {
return;
}
auto frame_step_value = get_scalar_value_from_tensor<int64_t>(frame_step);
auto frame_step_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(frame_step);

// Determine the size of the DFT based on the 2 optional inputs window and frame_length.
// One must be set.
Expand Down Expand Up @@ -3528,7 +3507,7 @@ ONNX_OPERATOR_SET_SCHEMA(
if (frame_length->dims_size() != 0) {
fail_shape_inference("frame_length input must be scalar.");
}
auto frame_length_value = get_scalar_value_from_tensor<int64_t>(frame_length);
auto frame_length_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(frame_length);

// Ensure that the window length and the dft_length match.
if (window_shape->dim_size() != 1) {
Expand Down Expand Up @@ -3559,7 +3538,7 @@ ONNX_OPERATOR_SET_SCHEMA(
if (frame_length->dims_size() != 0) {
fail_shape_inference("frame_length input must be scalar.");
}
dft_size = get_scalar_value_from_tensor<int64_t>(frame_length);
dft_size = defs::math::utils::GetScalarValueFromTensor<int64_t>(frame_length);
}

bool is_onesided = static_cast<bool>(getAttribute(ctx, "onesided", 0));
Expand Down
2 changes: 1 addition & 1 deletion onnx/defs/math/old.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3069,7 +3069,7 @@ ONNX_OPERATOR_SET_SCHEMA(
if (dft_length->dims_size() != 0) {
fail_shape_inference("dft_length input must be a scalar.");
}
auto dft_length_value = get_scalar_value_from_tensor<int64_t>(dft_length);
auto dft_length_value = defs::math::utils::GetScalarValueFromTensor<int64_t>(dft_length);
result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value);
}
// When DFT is onesided, the output shape is half the size of the input shape
Expand Down
25 changes: 25 additions & 0 deletions onnx/defs/math/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,30 @@

#include <string>

#include "onnx/defs/tensor_proto_util.h"

namespace ONNX_NAMESPACE {
namespace defs::math::utils {
template <typename T>
static T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t) {
if (t == nullptr) {
return T{};
}

auto data_type = t->data_type();
switch (data_type) {
case ONNX_NAMESPACE::TensorProto::FLOAT:
return static_cast<T>(ONNX_NAMESPACE::ParseData<float>(t).at(0));
case ONNX_NAMESPACE::TensorProto::DOUBLE:
return static_cast<T>(ONNX_NAMESPACE::ParseData<double>(t).at(0));
case ONNX_NAMESPACE::TensorProto::INT32:
return static_cast<T>(ONNX_NAMESPACE::ParseData<int32_t>(t).at(0));
case ONNX_NAMESPACE::TensorProto::INT64:
return static_cast<T>(ONNX_NAMESPACE::ParseData<int64_t>(t).at(0));
default:
fail_shape_inference("Unsupported input data type of ", data_type);
}
}
std::function<void(OpSchema&)>
SoftmaxFamilyDocGenerator(const char* name, const char* description, const char* equation) {
return [=](OpSchema& schema) {
Expand Down Expand Up @@ -69,4 +92,6 @@ from the back. Accepted range is [-r, r-1] where r = rank(input).
});
};
}
} // namespace defs::math::utils

} // namespace ONNX_NAMESPACE
5 changes: 5 additions & 0 deletions onnx/defs/math/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
#include "onnx/defs/schema.h"

namespace ONNX_NAMESPACE {
namespace defs::math::utils {
template <typename T>
static T GetScalarValueFromTensor(const ONNX_NAMESPACE::TensorProto* t);

std::function<void(OpSchema&)>
SoftmaxFamilyDocGenerator(const char* name, const char* description, const char* equation);
} // namespace defs::math::utils
} // namespace ONNX_NAMESPACE

0 comments on commit b712886

Please sign in to comment.